diff --git a/dm_control/mjcf/parser.py b/dm_control/mjcf/parser.py index 681f0293..482f751f 100644 --- a/dm_control/mjcf/parser.py +++ b/dm_control/mjcf/parser.py @@ -196,7 +196,11 @@ def _parse(xml_root, escape_separators=False, with debugging.freeze_current_stack_trace(): # Recursively parse any included XML files. to_include = [] - for include_tag in xml_root.findall('include'): + # Unsloth/User Fix: Use .iter() to find ALL include tags, not just root ones. + # We collect them first to avoid issues while modifying the tree. + include_tags = list(xml_root.iter('include')) + + for include_tag in include_tags: try: # First look for the path to the included XML file in the assets dict. path_or_xml_string = assets[include_tag.attrib['file']] @@ -212,9 +216,15 @@ def _parse(xml_root, escape_separators=False, resolve_references=resolve_references, assets=assets) to_include.append(included_mjcf) - # We must remove tags before parsing the main XML file, since - # these are a schema violation. - xml_root.remove(include_tag) + + # Remove the tag. + # Since we are iterating deeper, we need to find the parent to remove it. + parent = include_tag.getparent() + if parent is not None: + parent.remove(include_tag) + else: + # Fallback for root (though unlikely via iter unless it is self) + xml_root.remove(include_tag) # Parse the main XML file. try: