diff --git a/dm_control/mjcf/parser.py b/dm_control/mjcf/parser.py
index 681f0293..482f751f 100644
--- a/dm_control/mjcf/parser.py
+++ b/dm_control/mjcf/parser.py
@@ -196,7 +196,11 @@ def _parse(xml_root, escape_separators=False,
with debugging.freeze_current_stack_trace():
# Recursively parse any included XML files.
to_include = []
- for include_tag in xml_root.findall('include'):
+ # Unsloth/User Fix: Use .iter() to find ALL include tags, not just root ones.
+ # We collect them first to avoid issues while modifying the tree.
+ include_tags = list(xml_root.iter('include'))
+
+ for include_tag in include_tags:
try:
# First look for the path to the included XML file in the assets dict.
path_or_xml_string = assets[include_tag.attrib['file']]
@@ -212,9 +216,15 @@ def _parse(xml_root, escape_separators=False,
resolve_references=resolve_references,
assets=assets)
to_include.append(included_mjcf)
- # We must remove tags before parsing the main XML file, since
- # these are a schema violation.
- xml_root.remove(include_tag)
+
+ # Remove the tag.
+ # Since we are iterating deeper, we need to find the parent to remove it.
+ parent = include_tag.getparent()
+ if parent is not None:
+ parent.remove(include_tag)
+ else:
+ # Fallback for root (though unlikely via iter unless it is self)
+ xml_root.remove(include_tag)
# Parse the main XML file.
try: