|
1 | 1 | import errno |
2 | 2 | import glob |
| 3 | +import inspect |
3 | 4 | import itertools |
4 | 5 | import json |
5 | 6 | import logging |
@@ -57,27 +58,45 @@ def dict_pop_lazy(x, key, *args, **kwargs): |
57 | 58 | return default(y, *args, **kwargs) |
58 | 59 |
|
59 | 60 |
|
60 | | -def add_if_new(d, key, x, kwargs, model_name, in_keys, other_args=None): |
61 | | - # if key is list then assume model returns multiple args |
| 61 | +def try_use_model(model, model_name, input_vals): |
| 62 | + try: |
| 63 | + return default(None, model, input_vals, is_none) |
| 64 | + except TypeError as e: |
| 65 | + add_error_message( |
| 66 | + e, |
| 67 | + f"\n{model_name}.forward() signature is {inspect.signature(model.forward)}", |
| 68 | + ) |
| 69 | + raise |
| 70 | + |
| 71 | + |
| 72 | +def assign_to_output(d, key, x, new_x, condition): |
| 73 | + if len(x) > 1: |
| 74 | + if not is_list_or_tuple(new_x) or len(new_x) != len(x): |
| 75 | + raise TypeError( |
| 76 | + "if input x and key are lists, then output of model must be a list of the same length" |
| 77 | + ) |
| 78 | + for i in range(len(x)): |
| 79 | + if condition(x[i]): |
| 80 | + d[key[i]] = new_x[i] |
| 81 | + else: |
| 82 | + d[key[0]] = new_x |
| 83 | + |
| 84 | + |
| 85 | +def add_if_new(d, key, x, kwargs, model_name, in_keys, other_args=None, logger=None): |
| 86 | + other_args = default(other_args, {}) |
| 87 | + if logger: |
| 88 | + logger(f"Getting output: {key}") |
| 89 | + logger( |
| 90 | + f"Using model {model_name} with inputs: {', '.join(in_keys + list(other_args.keys()))}" |
| 91 | + ) |
62 | 92 | if not is_list_or_tuple(key) or not is_list_or_tuple(x): |
63 | 93 | raise TypeError("key and x must both be a list or tuple") |
64 | 94 | condition = is_none |
65 | 95 | if any(condition(y) for y in x): |
66 | 96 | model = kwargs[model_name] |
67 | | - input_vals = [kwargs[k] for k in in_keys] |
68 | | - if other_args is not None: |
69 | | - input_vals += other_args |
70 | | - new_x = default(None, model, input_vals, is_none) |
71 | | - if len(x) > 1: |
72 | | - if not is_list_or_tuple(new_x) or len(new_x) != len(x): |
73 | | - raise TypeError( |
74 | | - "if input x and key are lists, then output of model must be a list of the same length" |
75 | | - ) |
76 | | - for i in range(len(x)): |
77 | | - if condition(x[i]): |
78 | | - d[key[i]] = new_x[i] |
79 | | - else: |
80 | | - d[key[0]] = new_x |
| 97 | + input_vals = [kwargs[k] for k in in_keys] + list(other_args.values()) |
| 98 | + new_x = try_use_model(model, model_name, input_vals) |
| 99 | + assign_to_output(d, key, x, new_x, condition) |
81 | 100 |
|
82 | 101 |
|
83 | 102 | def class_default(cls, x, default): |
@@ -399,9 +418,19 @@ def attrs_of_type(cls, obj): |
399 | 418 | return {k: v for k, v in attrs.items() if isinstance(v, obj)} |
400 | 419 |
|
401 | 420 |
|
402 | | -def append_error_message(e, msg): |
| 421 | +# https://stackoverflow.com/a/70114007/16941290 |
| 422 | +class ErrorMsgWithNewLines(str): |
| 423 | + def __repr__(self): |
| 424 | + return str(self) |
| 425 | + |
| 426 | + |
| 427 | +def add_error_message(e, msg, prepend=False): |
403 | 428 | if len(e.args) >= 1: |
404 | | - e.args = (e.args[0] + msg,) + e.args[1:] |
| 429 | + if prepend: |
| 430 | + x = (ErrorMsgWithNewLines(msg + e.args[0]),) |
| 431 | + else: |
| 432 | + x = (ErrorMsgWithNewLines(e.args[0] + msg),) |
| 433 | + e.args = x + e.args[1:] |
405 | 434 |
|
406 | 435 |
|
407 | 436 | def requires_grad(x, does=True): |
|
0 commit comments