Skip to content

Commit 0cf5b15

Browse files
committed
Added function to standardise field validation.
1 parent 422d1f5 commit 0cf5b15

File tree

1 file changed

+96
-123
lines changed

1 file changed

+96
-123
lines changed

src/murfey/cli/generate_config.py

Lines changed: 96 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,21 @@ def confirm_overwrite(key: str):
7272
console.print("Invalid input. Please try again.", style="red")
7373

7474

75-
def populate_field(key: str, field: ModelField, debug: bool = False):
75+
def validate_value(value: Any, key: str, field: ModelField, debug: bool = False) -> Any:
76+
"""
77+
Helper function to validate the value of the desired field for a Pydantic model.
78+
"""
79+
validated_value, errors = field.validate(value, {}, loc=key)
80+
if errors:
81+
raise ValidationError(errors, MachineConfig)
82+
console.print(f"{key!r} validated successfully.", style="bright_green")
83+
if debug:
84+
console.print(f"Type: {type(validated_value)}", style="bright_green")
85+
console.print(f"{validated_value!r}", style="bright_green")
86+
return validated_value
87+
88+
89+
def populate_field(key: str, field: ModelField, debug: bool = False) -> Any:
7690
"""
7791
General function for inputting and validating the value of a single field against
7892
its Pydantic model.
@@ -91,17 +105,14 @@ def populate_field(key: str, field: ModelField, debug: bool = False):
91105
else answer
92106
)
93107

94-
validated_value, error = field.validate(value, {}, loc=key)
95-
if not error:
96-
console.print(f"{key!r} successfully validated", style="bright_green")
108+
# Validate and return
109+
try:
110+
return validate_value(value, key, field, debug)
111+
except ValidationError as error:
97112
if debug:
98-
console.print(
99-
f"{type(validated_value)}\n{validated_value!r}",
100-
style="bright_green",
101-
)
102-
return validated_value
103-
else:
104-
console.print("Invalid input. Please try again.", style="red")
113+
console.print(error, style="red")
114+
console.print(f"Invalid input for {key!r}. Please try again")
115+
continue
105116

106117

107118
def add_calibrations(
@@ -179,70 +190,53 @@ def get_calibration():
179190
add_calibration = ask_for_input(category="calibration setting", again=True)
180191

181192
# Validate the nested dictionary structure
182-
validated_calibrations, error = field.validate(calibrations, {}, loc=field)
183-
if not error:
184-
console.print(f"{key!r} validated successfully", style="bright_green")
193+
try:
194+
return validate_value(calibrations, key, field, debug)
195+
except ValidationError as error:
185196
if debug:
186-
console.print(
187-
f"{type(validated_calibrations)}\n{validated_calibrations!r}",
188-
style="bright_green",
189-
)
190-
return validated_calibrations
191-
else:
192-
console.print(
193-
f"Failed to validate the provided calibrations: {error}", style="red"
194-
)
197+
console.print(error, style="red")
198+
console.print(f"Failed to validate {key!r}", style="red")
195199
console.print("Returning an empty dictionary", style="red")
196200
return {}
197201

198202

199203
def add_software_packages(config: dict, debug: bool = False) -> dict[str, Any]:
200204
def get_software_name() -> str:
201-
name = (
202-
prompt(
203-
"What is the name of the software package? Supported options: 'autotem', "
204-
"'epu', 'leica', 'serialem', 'tomo'",
205-
style="yellow",
206-
)
207-
.lower()
208-
.strip()
205+
message = (
206+
"What is the name of the software package? Supported options: 'autotem', "
207+
"'epu', 'leica', 'serialem', 'tomo'"
209208
)
209+
name = prompt(message, style="yellow").lower().strip()
210210
# Validate name against "acquisition_software" field
211-
field = MachineConfig.__fields__["acquisition_software"]
212-
validated_name, error = field.validate([name], {}, loc="acquisition_software")
213-
if not error:
214-
return validated_name[0]
215-
console.print(
216-
"Invalid software name.",
217-
style="red",
218-
)
219-
if ask_for_input("software package", True) is True:
220-
return get_software_name()
221-
return ""
211+
try:
212+
field = MachineConfig.__fields__["acquisition_software"]
213+
return validate_value([name], "acquisition_software", field, False)[0]
214+
except ValidationError:
215+
console.print("Invalid software name.", style="red")
216+
if ask_for_input("software package", True) is True:
217+
return get_software_name()
218+
return ""
222219

223220
def ask_about_xml_path() -> bool:
224221
message = (
225222
"Does this software package have a settings file that needs modification? "
226223
"(y/n)"
227224
)
228-
answer = prompt(message, style="yellow").lower().strip()
229-
230-
# Validate
231-
if answer in ("y", "yes"):
232-
return True
233-
if answer in ("n", "no"):
234-
return False
235-
console.print("Invalid input.", style="red")
236-
return ask_about_xml_path()
225+
while True:
226+
answer = prompt(message, style="yellow").lower().strip()
227+
# Validate
228+
if answer in ("y", "yes"):
229+
return True
230+
if answer in ("n", "no"):
231+
return False
232+
console.print("Invalid input.", style="red")
237233

238234
def get_xml_file() -> Optional[Path]:
239-
xml_file = Path(
240-
prompt(
241-
"What is the full file path of the settings file? This should be an "
242-
"XML file.",
243-
style="yellow",
244-
)
235+
message = (
236+
"What is the full file path of the settings file? This should be an "
237+
"XML file."
245238
)
239+
xml_file = Path(prompt(message, style="yellow").strip())
246240
# Validate
247241
if xml_file.suffix:
248242
return xml_file
@@ -255,20 +249,18 @@ def get_xml_file() -> Optional[Path]:
255249
return None
256250

257251
def get_xml_tree_path() -> str:
258-
xml_tree_path = prompt(
259-
"What is the path through the XML file to the node to overwrite?",
260-
style="yellow",
261-
)
262-
# Possibly some validation checks later
252+
message = "What is the path through the XML file to the node to overwrite?"
253+
xml_tree_path = prompt(message, style="yellow").strip()
254+
# TODO: Currently no test cases for this method
263255
return xml_tree_path
264256

265257
def get_extensions_and_substrings() -> dict[str, list[str]]:
266258
def get_file_extension() -> str:
267-
extension = prompt(
259+
message = (
268260
"Please enter the extension of a file produced by this package "
269-
"that is to be analysed (e.g., '.tiff', '.eer', etc.).",
270-
style="yellow",
271-
).strip()
261+
"that is to be analysed (e.g., '.tiff', '.eer', etc.)."
262+
)
263+
extension = prompt(message, style="yellow").strip().lower()
272264
# Validate
273265
if not (extension.startswith(".") and extension.replace(".", "").isalnum()):
274266
console.print(
@@ -282,15 +274,15 @@ def get_file_extension() -> str:
282274
return extension
283275

284276
def get_file_substring() -> str:
285-
substring = prompt(
277+
message = (
286278
"Please enter a keyword that will be present in files with this "
287-
"extension. This field is case-sensitive.",
288-
style="yellow",
289-
).strip()
279+
"extension. This field is case-sensitive."
280+
)
281+
substring = prompt(message, style="yellow").strip()
290282
# Validate
291283
if bool(re.fullmatch(r"[\w\s\-]*", substring)) is False:
292284
console.print(
293-
"Invalid characters are present in this substring. Please "
285+
"Unsafe characters are present in this substring. Please "
294286
"try again. ",
295287
style="red",
296288
)
@@ -441,23 +433,13 @@ def get_file_substring() -> str:
441433
("data_required_substrings", data_required_substrings),
442434
)
443435
for field_name, value in to_validate:
444-
field = MachineConfig.__fields__[field_name]
445-
validated_value, error = field.validate(value, {}, loc=field_name)
446-
if not error:
447-
config[field_name] = validated_value
448-
console.print(
449-
f"{field_name!r} validated successfully", style="bright_green"
450-
)
436+
try:
437+
field = MachineConfig.__fields__[field_name]
438+
config[field_name] = validate_value(value, field_name, field, debug)
439+
except ValidationError as error:
451440
if debug:
452-
console.print(
453-
f"{type(validated_value)}\n{validated_value!r}",
454-
style="bright_green",
455-
)
456-
else:
457-
console.print(
458-
f"Validation failed due to the following error: {error}",
459-
style="red",
460-
)
441+
console.print(error, style="red")
442+
console.print(f"Failed to validate {field_name!r}", style="red")
461443
console.print("Please try again.", style="red")
462444
return add_software_packages(config)
463445

@@ -470,10 +452,7 @@ def add_data_directories(
470452
) -> dict[str, str]:
471453
def get_directory() -> Optional[Path]:
472454
message = "What is the full file path to the data directory you wish to add?"
473-
answer = prompt(
474-
message,
475-
style="yellow",
476-
).strip()
455+
answer = prompt(message, style="yellow").strip()
477456
# Convert "" into None
478457
if not answer:
479458
return None
@@ -520,17 +499,15 @@ def get_directory_type():
520499
continue
521500

522501
# Validate and return
523-
validated_data_directories, error = field.validate(data_directories, {}, loc=key)
524-
if not error:
525-
console.print(f"Validated {key!r} successfully", style="bright_green")
502+
try:
503+
return validate_value(data_directories, key, field, debug)
504+
except ValidationError as error:
526505
if debug:
527-
console.print(f"{type(validated_data_directories)}")
528-
console.print(f"{validated_data_directories!r}")
529-
return data_directories
530-
console.print(f"Failed to validate {key!r}", style="red")
531-
if ask_for_input(category, True) is True:
532-
return add_data_directories(key, field, debug)
533-
return {}
506+
console.print(error, style="red")
507+
console.print(f"Failed to validate {key!r}", style="red")
508+
if ask_for_input(category, True) is True:
509+
return add_data_directories(key, field, debug)
510+
return {}
534511

535512

536513
def add_create_directories(
@@ -593,17 +570,15 @@ def get_folder_alias() -> str:
593570
continue
594571

595572
# Validate and return
596-
validated_folders, errors = field.validate(folders_to_create, {}, loc=key)
597-
if not errors:
598-
console.print(f"{key!r} validated successfully", style="bright_green")
573+
try:
574+
return validate_value(folders_to_create, key, field, debug)
575+
except ValidationError as error:
599576
if debug:
600-
console.print(f"{type(validated_folders)}", style="bright_green")
601-
console.print(f"{validated_folders!r}", style="bright_green")
602-
return folders_to_create
603-
console.print(f"Failed to validate {key!r}")
604-
if ask_for_input(category, True) is True:
605-
return add_create_directories(key, field, debug)
606-
return {}
577+
console.print(error, style="red")
578+
console.print(f"Failed to validate {key!r}", style="red")
579+
if ask_for_input(category, True) is True:
580+
return add_create_directories(key, field, debug)
581+
return {}
607582

608583

609584
def add_analyse_created_directories(
@@ -626,7 +601,7 @@ def get_folder() -> str:
626601
"""
627602
Start of add_analyse_created_directories
628603
"""
629-
folders_to_create: list[str] = []
604+
folders_to_analyse: list[str] = []
630605
category = "folder for Murfey to analyse"
631606
add_folder = ask_for_input(category, False)
632607
while add_folder is True:
@@ -635,22 +610,20 @@ def get_folder() -> str:
635610
console.print("No folder name provided", style="red")
636611
add_folder = ask_for_input(category, True)
637612
continue
638-
folders_to_create.append(folder_name)
613+
folders_to_analyse.append(folder_name)
639614
add_folder = ask_for_input(category, True)
640615
continue
641616

642617
# Validate and return
643-
validated_folders, errors = field.validate(folders_to_create, {}, loc=key)
644-
if not errors:
645-
console.print(f"{key!r} validated successfully", style="bright_green")
618+
try:
619+
return sorted(validate_value(folders_to_analyse, key, field, debug))
620+
except ValidationError as error:
646621
if debug:
647-
console.print(f"{type(validated_folders)}", style="bright_green")
648-
console.print(f"{validated_folders!r}", style="bright_green")
649-
return sorted(validated_folders)
650-
console.print(f"Failed to validate {key!r}", style="red")
651-
if ask_for_input(category, True) is True:
652-
return add_analyse_created_directories(key, field, debug)
653-
return []
622+
console.print(error, style="red")
623+
console.print(f"Failed to validate {key!r}", style="red")
624+
if ask_for_input(category, True) is True:
625+
return add_analyse_created_directories(key, field, debug)
626+
return []
654627

655628

656629
def set_up_data_transfer(config: dict, debug: bool = False) -> dict:

0 commit comments

Comments
 (0)