Skip to content

Commit d71399d

Browse files
committed
improve test coverage
Signed-off-by: Gal Hubara Agam <[email protected]>
1 parent 1ac7868 commit d71399d

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

tests/unit/onnx/autocast/test_nodeclassifier.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,16 @@ def test_node_classifier_force_include(test_model):
369369
assert "add_node" in fp16_nodes
370370
assert "add_node" not in fp32_nodes
371371
assert not set(fp16_nodes).intersection(set(fp32_nodes))
372+
373+
# Set init_max low so both nodes would normally be excluded (kept in FP32)
374+
# Force op type Mul to low precision, despite exceeding init_max
375+
classifier3 = NodeClassifier(
376+
model=test_model,
377+
node_to_init_map=node_to_init_map,
378+
init_max=1.0,
379+
op_types_to_include=["Mul"],
380+
)
381+
fp16_nodes, fp32_nodes = classifier3.run()
382+
assert "mul_node" in fp16_nodes
383+
assert "add_node" in fp32_nodes
384+
assert not set(fp16_nodes).intersection(set(fp32_nodes))

0 commit comments

Comments
 (0)