@@ -116,7 +116,7 @@ def tosa_support_factory(
     # Negative checks: Remove nodes from partitioning
     negative_checks: list[OperatorSupportBase] = [
-        CheckInt64Inputs(exported_program, reporter),
+        CheckInt64InputsAndOutputs(exported_program, reporter),
         CheckFloat64Inputs(exported_program, reporter),
         RankCheck(reporter, max_rank=5),
         *[
@@ -454,7 +454,18 @@ def is_node_supported(
         return True


-class CheckInt64Inputs(OperatorSupportBase):
+class CheckInt64InputsAndOutputs(OperatorSupportBase):
+    """TOSA does not support int64 tensors, so in general ops with int64 inputs or outputs should not be partitioned.
+    There are however some exceptions:
+    - Nodes with int64 output can be partitioned if they are constant, within int32,
+      and all users cast to something else. In this case, the int64 tensor can safely be cast to int32 AOT.
+    - Nodes with int64 output can be partitioned if all users are getitem with non-int64 output.
+      In this case, there are multiple outputs and the int64 ones are not used.
+    - Nodes with int64 inputs can be partitioned if the inputs are constant placeholders, or constant
+      ops fulfilling the criteria above.
+    Note that we don't check placeholders here; they are partitioned based on whether their users are partitioned
+    or not.
+    """

     def __init__(
         self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
@@ -465,27 +476,85 @@ def __init__(
             if spec.kind == InputKind.USER_INPUT
         ]
         self.reporter = reporter
+        self.int32_min = torch.iinfo(torch.int32).min
+        self.int32_max = torch.iinfo(torch.int32).max
         super().__init__()

+    def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
+        """Node is assumed to be call_function with int64 output."""
+        if isinstance(node.target, str):
+            return False
+        data = node.target(*node.args, **node.kwargs)
+        min_val, max_val = int(torch.min(data)), int(torch.max(data))
+        return min_val >= self.int32_min and max_val <= self.int32_max
+
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:

+        vals = node.meta["val"]
+        tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]
+
+        any_int64 = any(tensor.dtype == torch.int64 for tensor in tensor_list)
+        # Don't partition nodes with int64 output...
+        if any_int64:
+            # ... Except for constant ops that are directly cast to something non-int64.
+            # This could be an explicit cast, or something like a less-than that outputs a different dtype than its input.
+            users_output_non_int64 = all(
+                get_first_fake_tensor(output_node).dtype != torch.int64
+                for output_node in node.users
+            )
+            if (
+                node.target in ComputeConstantOpsAOT.targeted_ops
+                and users_output_non_int64
+            ):
+                if not self.inside_int32_bounds(node):
+                    self.reporter.report_reject(
+                        node, "Constant node outside int32 range."
+                    )
+                    return False
+                # Will never have input nodes, safe to return True.
+                return True
+
+            # ... Or ops with multiple outputs where only the non-int64 ones are used.
+            users_are_getitem = all(
+                user.target == operator.getitem for user in node.users
+            )
+            if users_are_getitem and users_output_non_int64:
+                # Passed output check, go to input check.
+                pass
+            else:
+                self.reporter.report_reject(
+                    node, "Non-constant node with int64 output."
+                )
+                return False
+
+        # Ops with int64 inputs are only partitioned if the input nodes are constant and will be partitioned.
+        # If an input node is not partitioned, the partition will get an int64 input and fail.
         for input_node in node.all_input_nodes:
-            # We can cast constant placeholders and constant ops AOT, such int64 are ok.
-            # Otherwise, don't partition if one or more inputs are int64.
+            tensor_in = get_first_fake_tensor(input_node)
+            if tensor_in.dtype != torch.int64:
+                continue
+            # Constant placeholder
             if (
-                input_node.name in self.input_names
-                or not input_node.op == "placeholder"
+                input_node.op != "call_function"
+                and input_node.name not in self.input_names
             ):
-                tensor = get_first_fake_tensor(input_node)
-                if tensor.dtype == torch.int64:
-                    if input_node.target not in ComputeConstantOpsAOT.targeted_ops:
-                        self.reporter.report_reject(
-                            node,
-                            f"Had int64 input {input_node.name} that couldn't be handled.",
-                        )
-                        return False
+                continue
+            # Constant operator
+            if input_node.op == "call_function":
+                if input_node.target in ComputeConstantOpsAOT.targeted_ops:
+                    # This is not perfect since the input_node can still be rejected by other checks,
+                    # but it should cover the majority of cases.
+                    if self.is_node_supported(
+                        None, input_node  # type: ignore[arg-type]  # (we don't use 'submodules')
+                    ):
+                        continue
+            self.reporter.report_reject(
+                node, f"Non-constant int64 input {input_node.name}"
+            )
+            return False
+
         return True
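To make the new behaviour concrete, here is a minimal sketch, not part of this change, of the two graph patterns the check distinguishes. Module names and inputs are invented, and it assumes a recent PyTorch with torch.export and that torch.arange is among the constant ops computed AOT.

import torch
from torch.export import export


class ConstantThenCast(torch.nn.Module):
    # torch.arange produces an int64 constant whose values fit in int32, and its
    # only user casts it away from int64, so the node can stay in the partition
    # (the constant can safely be recomputed as int32 ahead of time),
    # assuming arange is one of the constant ops the partitioner computes AOT.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        idx = torch.arange(0, 4)  # int64 output, well inside int32 bounds
        return x + idx.to(torch.float32)


class Int64FromUserInput(torch.nn.Module):
    # The add consumes non-constant int64 tensors coming from user inputs, so it
    # has an int64 output that cannot be removed and is rejected from the partition.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


accepted_case = export(ConstantThenCast(), (torch.rand(4),))
rejected_case = export(
    Int64FromUserInput(),
    (torch.zeros(4, dtype=torch.int64), torch.ones(4, dtype=torch.int64)),
)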