@@ -27,17 +27,10 @@ def forward(self, weights: torch.Tensor, indices: torch.Tensor):
2727 return torch .embedding (weights , indices )
2828
2929
30- class ExpandEmbedding (Embedding ):
31- example_inputs = (torch .randn (10 , 3 ), torch .tensor ([[1 , 2 , 3 ]], dtype = torch .int32 ))
32-
33- def forward (self , weights : torch .Tensor , indices : torch .Tensor ):
34- return torch .embedding (weights , indices .expand (2 , 3 ))
35-
36-
37- input_params = Tuple [torch .Tensor , torch .Tensor ]
30+ input_params = Tuple [torch .Tensor , torch .Tensor , torch .dtype ]
3831
3932
40- test_input : dict [str , input_params ] = {
33+ test_input : dict [input_params ] = {
4134 "test_1" : (
4235 torch .randn (10 , 3 ),
4336 torch .tensor ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = torch .int32 ),
@@ -96,21 +89,6 @@ def test_embedding_tosa_INT(test_input: input_params):
9689 pipeline .run ()
9790
9891
99- def test_expand_embedding_tosa_INT ():
100- op = ExpandEmbedding ()
101- pipeline = TosaPipelineINT (
102- op ,
103- ExpandEmbedding .example_inputs ,
104- ExpandEmbedding .aten_op ,
105- ExpandEmbedding .exir_op ,
106- use_to_edge_transform_and_lower = True ,
107- )
108- pipeline .pop_stage ("check.aten" )
109- pipeline .pop_stage ("check_count.exir" )
110-
111- pipeline .run ()
112-
113-
11492@pytest .mark .skip ("reason=MLETORCH-1274 Improve data type checks during partitioning" )
11593@common .parametrize ("test_input" , test_input )
11694@common .SkipIfNoModelConverter
0 commit comments