1414"""
1515Summary of non-working cases.
1616MI:
17- Any case with int scalar: A to_copy is inserted to cast the value which we don't partition.
18- This makes the constant end up outside our partition and the input to the delegate becomes
19- a to_copy placeholder. In ArmTester, the placeholder is then interpreted as an input.
20- Potential fix: partition int -> float to_copy-ops in ArmBackend.
21- # MLETORCH-407
2217 Op(scalar, tensor):
2318 One issue is that lift_constant_tensor_pass looks for a fake_tensor in the meta of the first
2419 node which does not work the first node is a scalar.
2722 somewhere in _transform in the to_edge step. This makes ArmPartitioner miss tagging the
2823 data in tag_constant_data.
2924 # MLETORCH-408
30-
31- BI:
32- sub(Scalar, Tensor) becomes rsub, which either fails since the scalar does not become an attribute
33- in scalars_to_attribute_pass, or, if added to targeted_ops in that pass, fails since rsub expects a
34- Scalar.
35- Potential fix: Create pass to convert rsub.Scalar to sub.Tensor
25+ Sub or inplace-sub with an integer input.
3626"""
3727
3828
3929class TestScalars (unittest .TestCase ):
40- """Tests various scalar cases for for """
30+ """Tests various scalar cases"""
4131
4232 class Add (torch .nn .Module ):
4333 def forward (self , x , y ):
@@ -133,13 +123,10 @@ def forward(self, x):
133123 scalar = dtype [1 ]
134124 tensor_scalar_tests .append ((test_name + "_ts" , op [1 ], tensor , scalar ))
135125
136- # Don't add (scalar, tensor) test case for inplace and .Scalar ops.
137- if op [0 ][- 1 ] == "_" or op [ 0 ][ - 6 :] == "Scalar" :
126+ # Don't add (scalar, tensor) test case for .Scalar ops.
127+ if op [0 ][- 6 :] == "Scalar" :
138128 continue
139129
140- # sub(scalar, tensor) does not work in any case.
141- if op [0 ][0 :3 ] == "Sub" :
142- continue
143130 tensor_scalar_tests .append ((test_name + "_st" , op [1 ], scalar , tensor ))
144131
145132 tensor_const_tests = []
@@ -182,8 +169,8 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple):
182169 def test_MI (self , test_name : str , op : torch .nn .Module , x , y ):
183170 expected_exception = None
184171 if any (token in test_name for token in ("Sub_int" , "Sub__int" )):
185- expected_exception = ( AssertionError , ValueError )
186- elif test_name .endswith ("_st" ):
172+ expected_exception = AssertionError
173+ if test_name .endswith ("_st" ):
187174 expected_exception = AttributeError
188175
189176 if expected_exception :
@@ -204,5 +191,13 @@ def test_MI_const(self, test_name: str, op: torch.nn.Module, x):
204191 def test_BI (self , test_name : str , op : torch .nn .Module , x , y ):
205192 self ._test_add_tosa_BI_pipeline (op , (x , y ))
206193
194+ # op(Scalar float, tensor) works if the scalar is constant.
195+ @parameterized .expand (tensor_const_tests )
196+ def test_BI_const (self , test_name : str , op : torch .nn .Module , x ):
197+ self ._test_add_tosa_BI_pipeline (op , (x ,))
198+
207199 def test_shift_sub_inplace_tosa_MI (self ):
208200 self ._test_add_tosa_MI_pipeline (self .ShiftInplaceSub (), (torch .IntTensor (5 ),))
201+
202+ def test_shift_sub_inplace_tosa_BI (self ):
203+ self ._test_add_tosa_BI_pipeline (self .ShiftInplaceSub (), (torch .IntTensor (5 ),))
0 commit comments