@@ -563,18 +563,18 @@ def split(node: ir.Node, op, _):
563563 # Option A: Sizes per split
564564 if len (node .inputs ) == 2 :
565565 # Skip non-constant splits
566- if (_split := ir .convenience .get_const_tensor (node .inputs [1 ])) is None :
566+ if (split_ := ir .convenience .get_const_tensor (node .inputs [1 ])) is None :
567567 return None
568568 # Numpy expects splits as starting indices for each section
569- _split = np .cumsum (_split .numpy ()[:- 1 ])
569+ split_ = np .cumsum (split_ .numpy ()[:- 1 ])
570570
571571 # Option B: Number of (even) splits
572572 elif (num_outputs := node .attributes .get ("num_outputs" )) is not None :
573573 # Numpy accepts single integer of (even) splits as well
574- _split = num_outputs .as_int ()
574+ split_ = num_outputs .as_int ()
575575
576576 # Unable to determine split configuration, skip optimization
577- if _split is None :
577+ if split_ is None :
578578 return None
579579
580580 # Default split axis is 0, according to ONNX operators reference:
@@ -583,7 +583,7 @@ def split(node: ir.Node, op, _):
583583 axis = ir .Attr ("axis" , ir .AttributeType .INT , 0 )
584584
585585 # Split constant tensor and wrap a list of Constant operators
586- splits = np .array_split (x .numpy (), _split , axis .as_int ())
586+ splits = np .array_split (x .numpy (), split_ , axis .as_int ())
587587 return [op .Constant (value = ir .tensor (x )) for x in splits ]
588588
589589
0 commit comments