We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1ac7868 commit d71399dCopy full SHA for d71399d
tests/unit/onnx/autocast/test_nodeclassifier.py
@@ -369,3 +369,16 @@ def test_node_classifier_force_include(test_model):
369
assert "add_node" in fp16_nodes
370
assert "add_node" not in fp32_nodes
371
assert not set(fp16_nodes).intersection(set(fp32_nodes))
372
+
373
+ # Set init_max low so both nodes would normally be excluded (kept in FP32)
374
+ # Force op type Mul to low precision, despite exceeding init_max
375
+ classifier3 = NodeClassifier(
376
+ model=test_model,
377
+ node_to_init_map=node_to_init_map,
378
+ init_max=1.0,
379
+ op_types_to_include=["Mul"],
380
+ )
381
+ fp16_nodes, fp32_nodes = classifier3.run()
382
+ assert "mul_node" in fp16_nodes
383
+ assert "add_node" in fp32_nodes
384
+ assert not set(fp16_nodes).intersection(set(fp32_nodes))
0 commit comments