@@ -4218,3 +4218,113 @@ func.func @torch.aten.convolution$si8(%arg0: !torch.vtensor<[2,2,6,6],si8>, %arg
  %4 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %1, %2, %false, %3, %int1 : !torch.vtensor<[2,2,6,6],si8>, !torch.vtensor<[8,2,3,3],si8>, !torch.vtensor<[8],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[2,8,4,4],si32>
  return %4 : !torch.vtensor<[2,8,4,4],si32>
}
+ // CHECK-LABEL: func.func @torch.aten.mm$f32(
+ // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[1,22],f32>,
+ // CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[22,10],f32>) -> !torch.vtensor<[1,10],f32> {
+ // CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[22,10],f32> -> tensor<22x10xf32>
+ // CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[1,22],f32> -> tensor<1x22xf32>
+ // CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 22]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<1x22xf32>, !tosa.shape<3>) -> tensor<1x1x22xf32>
+ // CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 22, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TENSOR]], %[[WTS_SHAPE]] : (tensor<22x10xf32>, !tosa.shape<3>) -> tensor<1x22x10xf32>
+ // CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x1x22xf32>, tensor<1x22x10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x10xf32>
+ // CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x1x10xf32>, !tosa.shape<2>) -> tensor<1x10xf32>
+ // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[RES_RESHAPE]] : tensor<1x10xf32> -> !torch.vtensor<[1,10],f32>
+ // CHECK: return %[[RES]]
+ func.func @torch.aten.mm$f32(%arg0: !torch.vtensor<[1,22],f32>, %arg1: !torch.vtensor<[22,10],f32>) -> !torch.vtensor<[1,10],f32> {
+   %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],f32>, !torch.vtensor<[22,10],f32> -> !torch.vtensor<[1,10],f32>
+   return %0 : !torch.vtensor<[1,10],f32>
+ }
+
+ // -----
+ // CHECK-LABEL: func.func @torch.aten.mm$si8
+ // CHECK: tosa.matmul
+ // CHECK-SAME: (tensor<1x1x22xi8>, tensor<1x22x10xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x10xi32>
+ // CHECK-NOT: torch.aten.mm
+ // CHECK: tosa.cast
+ // CHECK-SAME: (tensor<1x10xi32>) -> tensor<1x10xi8>
+ func.func @torch.aten.mm$si8(%arg0: !torch.vtensor<[1,22],si8>, %arg1: !torch.vtensor<[22,10],si8>) -> !torch.vtensor<[1,10],si8> {
+   %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],si8>, !torch.vtensor<[22,10],si8> -> !torch.vtensor<[1,10],si8>
+   return %0 : !torch.vtensor<[1,10],si8>
+ }
+
+ // -----
+ // CHECK-LABEL: func.func @torch.aten.mm$f16
+ // CHECK: tosa.matmul
+ // CHECK-SAME: (tensor<1x1x22xf16>, tensor<1x22x10xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x1x10xf32>
+ // CHECK-NOT: torch.aten.mm
+ // CHECK: tosa.cast
+ // CHECK-SAME: (tensor<1x10xf32>) -> tensor<1x10xf16>
+ func.func @torch.aten.mm$f16(%arg0: !torch.vtensor<[1,22],f16>, %arg1: !torch.vtensor<[22,10],f16>) -> !torch.vtensor<[1,10],f16> {
+   %4 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],f16>, !torch.vtensor<[22,10],f16> -> !torch.vtensor<[1,10],f16>
+   return %4 : !torch.vtensor<[1,10],f16>
+ }
+
+ // -----
+ // CHECK-LABEL: func.func @torch.aten.mm$bf16
+ // CHECK: tosa.matmul
+ // CHECK-SAME: (tensor<1x1x22xbf16>, tensor<1x22x10xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1x1x10xf32>
+ // CHECK-NOT: torch.aten.mm
+ // CHECK: tosa.cast
+ // CHECK-SAME: (tensor<1x10xf32>) -> tensor<1x10xbf16>
+ func.func @torch.aten.mm$bf16(%arg0: !torch.vtensor<[1,22],bf16>, %arg1: !torch.vtensor<[22,10],bf16>) -> !torch.vtensor<[1,10],bf16> {
+   %4 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],bf16>, !torch.vtensor<[22,10],bf16> -> !torch.vtensor<[1,10],bf16>
+   return %4 : !torch.vtensor<[1,10],bf16>
+ }
+
+ // -----
+ // CHECK-LABEL: func.func @torch.aten.matmul$broadcast(
+ // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[10,3,4],f32>,
+ // CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[4],f32>) -> !torch.vtensor<[10,3],f32> {
+ // CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[4],f32> -> tensor<4xf32>
+ // CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[10,3,4],f32> -> tensor<10x3x4xf32>
+ // CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TENSOR]], %[[WTS_SHAPE]] : (tensor<4xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32>
+ // CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 30, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<10x3x4xf32>, !tosa.shape<3>) -> tensor<1x30x4xf32>
+ // CHECK: %[[WTS_TRANSPOSE:.*]] = tosa.transpose %[[WTS_RESHAPE]] {perms = array<i32: 1, 0, 2>} : (tensor<1x4x1xf32>) -> tensor<4x1x1xf32>
+ // CHECK: %[[WTS_SHAPE_2:.*]] = tosa.const_shape {values = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[WTS_RESHAPE_2:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE_2]] : (tensor<4x1x1xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32>
+ // CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE_2]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x30x4xf32>, tensor<1x4x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x30x1xf32>
+ // CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[10, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x30x1xf32>, !tosa.shape<2>) -> tensor<10x3xf32>
+ // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[RES_RESHAPE]] : tensor<10x3xf32> -> !torch.vtensor<[10,3],f32>
+ // CHECK: return %[[RES]]
+ func.func @torch.aten.matmul$broadcast(%arg0: !torch.vtensor<[10,3,4],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[10,3],f32> {
+   %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,3,4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[10,3],f32>
+   return %0 : !torch.vtensor<[10,3],f32>
+ }
+
+ // -----
+ // CHECK-LABEL: func.func @torch.aten.linear$f16(
+ // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[2,4],f16>,
+ // CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[3,4],f16>,
+ // CHECK-SAME: %[[BIAS:.*]]: !torch.vtensor<[3],f16>) -> !torch.vtensor<[2,3],f16> {
+ // CHECK: %[[BIAS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[BIAS]] : !torch.vtensor<[3],f16> -> tensor<3xf16>
+ // CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[3,4],f16> -> tensor<3x4xf16>
+ // CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[2,4],f16> -> tensor<2x4xf16>
+ // CHECK: %[[WTS_TRANSPOSE:.*]] = tosa.transpose %[[WTS_TENSOR]] {perms = array<i32: 1, 0>} : (tensor<3x4xf16>) -> tensor<4x3xf16>
+ // CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 2, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<2x4xf16>, !tosa.shape<3>) -> tensor<1x2x4xf16>
+ // CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE]] : (tensor<4x3xf16>, !tosa.shape<3>) -> tensor<1x4x3xf16>
+ // CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
+ // CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
+ // CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x2x4xf16>, tensor<1x4x3xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x2x3xf32>
+ // CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[2, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x2x3xf32>, !tosa.shape<2>) -> tensor<2x3xf32>
+ // CHECK: %[[CAST:.*]] = tosa.cast %[[RES_RESHAPE]] : (tensor<2x3xf32>) -> tensor<2x3xf16>
+ // CHECK: %[[BIAS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: %[[BIAS_RESHAPE:.*]] = tosa.reshape %[[BIAS_TENSOR]], %[[BIAS_SHAPE]] : (tensor<3xf16>, !tosa.shape<2>) -> tensor<1x3xf16>
+ // CHECK: %[[ADD:.*]] = tosa.add %[[CAST]], %[[BIAS_RESHAPE]] : (tensor<2x3xf16>, tensor<1x3xf16>) -> tensor<2x3xf16>
+ // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[ADD]] : tensor<2x3xf16> -> !torch.vtensor<[2,3],f16>
+ // CHECK: return %[[RES]]
+ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch.vtensor<[3,4],f16>, %arg2: !torch.vtensor<[3],f16>) -> !torch.vtensor<[2,3],f16> {
+   %0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
+   return %0 : !torch.vtensor<[2,3],f16>
+ }