Skip to content

Commit 239596d

Browse files
committed
Fix bug in CLI with calling a factory-fn inside a list
Signed-off-by: Marc Romeyn <[email protected]>
1 parent f3c8e99 commit 239596d

File tree

1 file changed

+72
-12
lines changed

1 file changed

+72
-12
lines changed

nemo_run/cli/cli_parser.py

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -796,14 +796,48 @@ def parse_list(self, value: str, annotation: Type[List]) -> List:
796796
Raises:
797797
ListParseError: If the value cannot be parsed as a list.
798798
"""
799-
try:
800-
parsed = ast.literal_eval(value)
801-
if not isinstance(parsed, list):
802-
raise ValueError("Not a list")
803-
elem_type = get_args(annotation)[0]
804-
return [self.parse(str(item), elem_type) for item in parsed]
805-
except Exception as e:
806-
raise ListParseError(value, List, f"Invalid list: {str(e)}")
799+
# Remove outer brackets and whitespace
800+
if not (value.startswith("[") and value.endswith("]")):
801+
raise ListParseError(value, List, "List must be enclosed in square brackets")
802+
803+
inner = value.strip("[] ")
804+
elements = []
805+
current = ""
806+
nesting = 0
807+
808+
# Parse character by character to handle nested structures
809+
for char in inner:
810+
if char == "," and nesting == 0:
811+
if current.strip():
812+
elements.append(current.strip())
813+
current = ""
814+
else:
815+
if char in "([{":
816+
nesting += 1
817+
elif char in ")]}":
818+
nesting -= 1
819+
current += char
820+
821+
# Add the last element if it exists
822+
if current.strip():
823+
elements.append(current.strip())
824+
825+
# Process each element - try literal_eval first, fallback to string
826+
parsed_elements = []
827+
for element in elements:
828+
try:
829+
parsed_element = ast.literal_eval(element)
830+
except Exception:
831+
parsed_element = element
832+
parsed_elements.append(parsed_element)
833+
834+
print(parsed_elements)
835+
import pdb
836+
837+
pdb.set_trace()
838+
839+
elem_type = get_args(annotation)[0]
840+
return [self.parse(str(item), elem_type) for item in parsed_elements]
807841

808842
def parse_dict(self, value: str, annotation: Type[Dict]) -> Dict:
809843
"""Parse a string value into a dictionary of the specified key-value types.
@@ -1129,6 +1163,10 @@ def dummy_model_config():
11291163
raise UndefinedVariableError(
11301164
f"Cannot use '{op.value}' on undefined variable", arg, {"key": key}
11311165
)
1166+
# a = parser.apply_operation(op, getattr(nested, arg_name), parsed_value)
1167+
import pdb
1168+
1169+
pdb.set_trace()
11321170
setattr(
11331171
nested,
11341172
arg_name,
@@ -1253,13 +1291,35 @@ def parse_single_factory(factory_str):
12531291
return factory_fn()
12541292

12551293
# Check if the value is a list
1256-
list_match = re.match(r"^\s*\[(.*)\]\s*$", value)
1257-
if list_match:
1294+
if value.startswith("[") and value.endswith("]"):
12581295
# Check if arg_type is List[T], if so get T
12591296
if get_origin(arg_type) is list:
12601297
arg_type = get_args(arg_type)[0]
1261-
items = re.findall(r"([^,]+(?:\([^)]*\))?)", list_match.group(1))
1262-
return [parse_single_factory(item.strip()) for item in items]
1298+
1299+
# Parse list with nested structure handling
1300+
inner = value.strip("[] ")
1301+
elements = []
1302+
current = ""
1303+
nesting = 0
1304+
1305+
# Parse character by character to handle nested structures
1306+
for char in inner:
1307+
if char == "," and nesting == 0:
1308+
if current.strip():
1309+
elements.append(current.strip())
1310+
current = ""
1311+
else:
1312+
if char in "([{":
1313+
nesting += 1
1314+
elif char in ")]}":
1315+
nesting -= 1
1316+
current += char
1317+
1318+
# Add the last element if it exists
1319+
if current.strip():
1320+
elements.append(current.strip())
1321+
1322+
return [parse_single_factory(item.strip()) for item in elements]
12631323

12641324
return parse_single_factory(value)
12651325

0 commit comments

Comments
 (0)