Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 61 additions & 4 deletions robosuite/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,26 @@ def get_element_names(self, root, element_type):
names += self.get_element_names(child, element_type)
return names

# @staticmethod
# def _get_default_classes(default):
# """
# Utility method to convert all default tags into a nested dictionary of values -- this will be used to replace
# all elements' class tags inline with the appropriate defaults if not specified.

# Args:
# default (ET.Element): Nested default tag XML root.

# Returns:
# dict: Nested dictionary, where each default class name is mapped to its own dict mapping element tag names
# (e.g.: geom, site, etc.) to the set of default attributes for that tag type
# """
# # Create nested dict to return
# default_dic = {}
# # Parse the default tag accordingly
# for cls in default:
# default_dic[cls.get("class")] = {child.tag: child for child in cls}
# return default_dic

@staticmethod
def _get_default_classes(default):
"""
Expand All @@ -218,13 +238,50 @@ def _get_default_classes(default):
dict: Nested dictionary, where each default class name is mapped to its own dict mapping element tag names
(e.g.: geom, site, etc.) to the set of default attributes for that tag type
"""
# Create nested dict to return
default_dic = {}
# Parse the default tag accordingly
def _parse(cur_element, record, parent_name=None):
cur_record = {'parent': parent_name, 'data': {}}
record[cur_element.get('class')] = cur_record
for child_element in cur_element:
if child_element.tag == 'default':
# cur_record['parent'] = cur_element.get('class')
_parse(child_element,
record,
parent_name=[cur_element.get('class')] if parent_name
is None else parent_name + [cur_element.get('class')])
else:
cur_record['data'][child_element.tag] = child_element

record = {}
for cls in default:
default_dic[cls.get("class")] = {child.tag: child for child in cls}
_parse(cls, record)
# Update
for cls_name, cls_record in record.items():
if cls_record['parent'] is None:
continue
for parent_cls_name in cls_record['parent']:
for element_name, element in cls_record['data'].items():
for p_element_name, p_element in record[parent_cls_name][
'data'].items():
if element.tag == p_element.tag:
for attr_name, attr_value in p_element.items():
if attr_name not in element.attrib:
element.set(attr_name, attr_value)
# Update child element
for p_element_name, p_element in record[parent_cls_name]['data'].items():
if p_element.tag not in record[cls_name]['data']:
record[cls_name]['data'][p_element.tag] = p_element
print(f"record: {record}")
# Build default_dic
default_dic = {}
for cls_name, cls_record in record.items():
default_dic[cls_name] = {
element.tag: element
for element in cls_record['data'].values()
}
print(f"default_dic: {default_dic}")
return default_dic


def _replace_defaults_inline(self, default_dic, root=None):
"""
Utility method to replace all default class attributes recursively in the XML tree starting from @root
Expand Down
Loading