diff --git a/robosuite/models/base.py b/robosuite/models/base.py index 52e6f98d47..ae23f33b00 100644 --- a/robosuite/models/base.py +++ b/robosuite/models/base.py @@ -205,6 +205,26 @@ def get_element_names(self, root, element_type): names += self.get_element_names(child, element_type) return names + # @staticmethod + # def _get_default_classes(default): + # """ + # Utility method to convert all default tags into a nested dictionary of values -- this will be used to replace + # all elements' class tags inline with the appropriate defaults if not specified. + + # Args: + # default (ET.Element): Nested default tag XML root. + + # Returns: + # dict: Nested dictionary, where each default class name is mapped to its own dict mapping element tag names + # (e.g.: geom, site, etc.) to the set of default attributes for that tag type + # """ + # # Create nested dict to return + # default_dic = {} + # # Parse the default tag accordingly + # for cls in default: + # default_dic[cls.get("class")] = {child.tag: child for child in cls} + # return default_dic + @staticmethod def _get_default_classes(default): """ @@ -218,13 +238,50 @@ def _get_default_classes(default): dict: Nested dictionary, where each default class name is mapped to its own dict mapping element tag names (e.g.: geom, site, etc.) to the set of default attributes for that tag type """ - # Create nested dict to return - default_dic = {} - # Parse the default tag accordingly + def _parse(cur_element, record, parent_name=None): + cur_record = {'parent': parent_name, 'data': {}} + record[cur_element.get('class')] = cur_record + for child_element in cur_element: + if child_element.tag == 'default': + # cur_record['parent'] = cur_element.get('class') + _parse(child_element, + record, + parent_name=[cur_element.get('class')] if parent_name + is None else parent_name + [cur_element.get('class')]) + else: + cur_record['data'][child_element.tag] = child_element + + record = {} for cls in default: - default_dic[cls.get("class")] = {child.tag: child for child in cls} + _parse(cls, record) + # Update + for cls_name, cls_record in record.items(): + if cls_record['parent'] is None: + continue + for parent_cls_name in cls_record['parent']: + for element_name, element in cls_record['data'].items(): + for p_element_name, p_element in record[parent_cls_name][ + 'data'].items(): + if element.tag == p_element.tag: + for attr_name, attr_value in p_element.items(): + if attr_name not in element.attrib: + element.set(attr_name, attr_value) + # Update child element + for p_element_name, p_element in record[parent_cls_name]['data'].items(): + if p_element.tag not in record[cls_name]['data']: + record[cls_name]['data'][p_element.tag] = p_element + print(f"record: {record}") + # Build default_dic + default_dic = {} + for cls_name, cls_record in record.items(): + default_dic[cls_name] = { + element.tag: element + for element in cls_record['data'].values() + } + print(f"default_dic: {default_dic}") return default_dic + def _replace_defaults_inline(self, default_dic, root=None): """ Utility method to replace all default class attributes recursively in the XML tree starting from @root