@@ -569,9 +569,10 @@ func.func @test_matmulinteger(%arg0: !torch.vtensor<[4,3],ui8>, %arg1: !torch.vt
   %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[4,2],si32>
   // CHECK: %[[LITEM:.+]] = torch.aten.item %arg2
   // CHECK: %[[RITEM:.+]] = torch.aten.item %arg3
-  // CHECK: %[[SCALE:.+]] = torch.constant.float 1.000000e+00
-  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[LITEM]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8>
-  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
+  // CHECK: %[[L_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[L_SCALE]], %[[LITEM]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8>
+  // CHECK: %[[R_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[R_SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
   // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]]
   // CHECK: return %[[MM]]
   return %0 : !torch.vtensor<[4,2],si32>
@@ -584,13 +585,59 @@ func.func @test_matmulinteger_batched(%arg0: !torch.vtensor<[7,4,3],ui8>, %arg1:
   %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[7,4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[7,4,2],si32>
   // CHECK: %[[LITEM:.+]] = torch.aten.item %arg2
   // CHECK: %[[RITEM:.+]] = torch.aten.item %arg3
-  // CHECK: %[[SCALE:.+]] = torch.constant.float 1.000000e+00
-  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[LITEM]] : !torch.vtensor<[7,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[7,4,3],!torch.quint8>
-  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
+  // CHECK: %[[L_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[L_SCALE]], %[[LITEM]] : !torch.vtensor<[7,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[7,4,3],!torch.quint8>
+  // CHECK: %[[R_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[R_SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
   // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]]
   // CHECK: return %[[MM]]
   return %0 : !torch.vtensor<[7,4,2],si32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_matmulinteger_non_scalar_lhsZp(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[16,2],ui8>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,768],si8>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16],ui8>,
+// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+func.func @test_matmulinteger_non_scalar_lhsZp(%arg0: !torch.vtensor<[16,2],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[16],ui8>, %arg3: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+  // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[],si8> -> !torch.int
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 6
+  // CHECK: %[[VAL_6:.*]] = torch.constant.none
+  // CHECK: %[[VAL_7:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_8:.*]] = torch.aten.ones_like %[[VAL_2]], %[[VAL_5]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[16],ui8>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[16],f32>
+  // CHECK: %[[VAL_9:.*]] = torch.aten._make_per_channel_quantized_tensor %[[VAL_0]], %[[VAL_8]], %[[VAL_2]], %[[VAL_7]] : !torch.vtensor<[16,2],ui8>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],ui8>, !torch.int -> !torch.vtensor<[16,2],!torch.quint8>
+  // CHECK: %[[VAL_10:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[VAL_11:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[VAL_1]], %[[VAL_10]], %[[VAL_4]] : !torch.vtensor<[2,768],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,768],!torch.qint8>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.matmul %[[VAL_9]], %[[VAL_11]] : !torch.vtensor<[16,2],!torch.quint8>, !torch.vtensor<[2,768],!torch.qint8> -> !torch.vtensor<[16,768],si32>
+  // CHECK: return %[[VAL_12]] : !torch.vtensor<[16,768],si32>
+  %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[16,2],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[16],ui8>, !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32>
+  return %0 : !torch.vtensor<[16,768],si32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_matmulinteger_non_scalar_rhsZp(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],ui8>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,768],si8>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],ui8>,
+// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+func.func @test_matmulinteger_non_scalar_rhsZp(%arg0: !torch.vtensor<[?,?],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+  // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 6
+  // CHECK: %[[VAL_6:.*]] = torch.constant.none
+  // CHECK: %[[VAL_7:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[VAL_8:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[VAL_0]], %[[VAL_7]], %[[VAL_4]] : !torch.vtensor<[?,?],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,?],!torch.quint8>
+  // CHECK: %[[VAL_9:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_10:.*]] = torch.aten.ones_like %[[VAL_3]], %[[VAL_5]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[768],si8>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[768],f32>
+  // CHECK: %[[VAL_11:.*]] = torch.aten._make_per_channel_quantized_tensor %[[VAL_1]], %[[VAL_10]], %[[VAL_3]], %[[VAL_9]] : !torch.vtensor<[2,768],si8>, !torch.vtensor<[768],f32>, !torch.vtensor<[768],si8>, !torch.int -> !torch.vtensor<[2,768],!torch.qint8>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.matmul %[[VAL_8]], %[[VAL_11]] : !torch.vtensor<[?,?],!torch.quint8>, !torch.vtensor<[2,768],!torch.qint8> -> !torch.vtensor<[?,768],si32>
+  // CHECK: return %[[VAL_12]] : !torch.vtensor<[?,768],si32>
+  %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[?,?],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[],ui8>, !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32>
+  return %0 : !torch.vtensor<[?,768],si32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @test_mul