Skip to content

Commit 2e440fc

Browse files
committed
lint
Signed-off-by: xadupre <[email protected]>
1 parent 4e13ee8 commit 2e440fc

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

tf2onnx/convert.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -438,11 +438,14 @@ def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom
438438
if hasattr(model, "inputs"):
439439
model_input = model.inputs
440440
elif hasattr(model, "input_dtype") and hasattr(model, "_build_shapes_dict"):
441-
if len(model._build_shapes_dict) == 1:
442-
shape = list(model._build_shapes_dict.values())[0]
441+
if len(model._build_shapes_dict) == 1: # noqa: W0212
442+
shape = list(model._build_shapes_dict.values())[0] # noqa: W0212
443443
model_input = [tf.Variable(tf.zeros(shape, dtype=model.input_dtype), name="input")]
444444
else:
445-
raise RuntimeError(f"Not implemented yet with input_dtype={model.input_dtype} and model._build_shapes_dict={model._build_shapes_dict}")
445+
raise RuntimeError(
446+
f"Not implemented yet with input_dtype={model.input_dtype} "
447+
f"and model._build_shapes_dict={model._build_shapes_dict}" # noqa: W0212
448+
)
446449
else:
447450
if not hasattr(model, "inputs_spec"):
448451
raise RuntimeError("You may set attribute 'inputs_spec' with your inputs (model.input_specs = ...)")
@@ -493,15 +496,19 @@ def _get_name(t, i):
493496
try:
494497
return t.name
495498
except AttributeError:
496-
return f"output:{i}"
499+
return f"output:{i}"
497500

498501
for out in [_get_name(t, i) for i, t in enumerate(model_output)]:
499502
if out in reverse_lookup:
500503
valid_names.append(reverse_lookup[out])
501504
else:
502505
print(f"Warning: Output name '{out}' not found in reverse_lookup.")
503506
# Fallback: verwende TensorFlow-Ausgangsnamen direkt
504-
valid_names = [_get_name(t, i) for i, t in enumerate(concrete_func.outputs) if t.dtype != tf.dtypes.resource]
507+
valid_names = [
508+
_get_name(t, i)
509+
for i, t in enumerate(concrete_func.outputs)
510+
if t.dtype != tf.dtypes.resource
511+
]
505512
break
506513
output_names = valid_names
507514

0 commit comments

Comments
 (0)