@@ -3679,3 +3679,36 @@ func.func @test_rotary_embedding(%arg0: !torch.vtensor<[1,3,2,6],f32>, %arg1: !t
36793679 %4 = torch.operator " onnx.RotaryEmbedding" (%arg0 , %arg1 , %arg2 , %arg3 ) : (!torch.vtensor <[1 ,3 ,2 ,6 ],f32 >, !torch.vtensor <[1 ,2 ],si64 >, !torch.vtensor <[4 ,3 ],f32 >, !torch.vtensor <[4 ,3 ],f32 >) -> !torch.vtensor <[1 ,3 ,2 ,6 ],f32 >
36803680 return %4 : !torch.vtensor <[1 ,3 ,2 ,6 ],f32 >
36813681}
3682+
3683+ // -----
3684+
3685+ // CHECK-LABEL: @test_qlinearadd(
3686+ // CHECK-SAME: %[[A:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,4096],ui8>,
3687+ // CHECK-SAME: %[[A_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
3688+ // CHECK-SAME: %[[A_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
3689+ // CHECK-SAME: %[[B:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[4096],ui8>,
3690+ // CHECK-SAME: %[[B_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
3691+ // CHECK-SAME: %[[B_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
3692+ // CHECK-SAME: %[[C_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
3693+ // CHECK-SAME: %[[C_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,4096],ui8>
3694+ func.func @test_qlinearadd (%arg0: !torch.vtensor <[1 ,4096 ],ui8 >, %arg1: !torch.vtensor <[],f32 >, %arg2: !torch.vtensor <[],ui8 >, %arg3: !torch.vtensor <[4096 ],ui8 >, %arg4: !torch.vtensor <[],f32 >, %arg5: !torch.vtensor <[],ui8 >, %arg6: !torch.vtensor <[],f32 >, %arg7: !torch.vtensor <[],ui8 >) -> !torch.vtensor <[1 ,4096 ],ui8 > attributes {torch.onnx_meta.opset_version = 10 : si64 } {
3695+ %0 = torch.operator " onnx.QLinearAdd" (%arg0 , %arg1 , %arg2 , %arg3 , %arg4 , %arg5 , %arg6 , %arg7 ) : (!torch.vtensor <[1 ,4096 ],ui8 >, !torch.vtensor <[],f32 >, !torch.vtensor <[],ui8 >, !torch.vtensor <[4096 ],ui8 >, !torch.vtensor <[],f32 >, !torch.vtensor <[],ui8 >, !torch.vtensor <[],f32 >, !torch.vtensor <[],ui8 >) -> !torch.vtensor <[1 ,4096 ],ui8 >
3696+ // CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
3697+ // CHECK-DAG: %[[AZP:.+]] = torch.aten.item %[[A_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
3698+ // CHECK-DAG: %[[BZP:.+]] = torch.aten.item %[[B_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
3699+ // CHECK-DAG: %[[CZP:.+]] = torch.aten.item %[[C_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
3700+ // CHECK-DAG: %[[ASCALE:.+]] = torch.aten.item %[[A_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
3701+ // CHECK-DAG: %[[BSCALE:.+]] = torch.aten.item %[[B_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
3702+ // CHECK-DAG: %[[CSCALE:.+]] = torch.aten.item %[[C_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
3703+ // CHECK-DAG: %[[A_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[A]], %[[ASCALE]], %[[AZP]] : !torch.vtensor<[1,4096],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,4096],!torch.quint8>
3704+ // CHECK-DAG: %[[B_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[B]], %[[BSCALE]], %[[BZP]] : !torch.vtensor<[4096],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4096],!torch.quint8>
3705+ // CHECK: %[[A_F32:.+]] = torch.aten.dequantize.self %[[A_QUANT]] : !torch.vtensor<[1,4096],!torch.quint8> -> !torch.vtensor<[1,4096],f32>
3706+ // CHECK: %[[B_F32:.+]] = torch.aten.dequantize.self %[[B_QUANT]] : !torch.vtensor<[4096],!torch.quint8> -> !torch.vtensor<[4096],f32>
3707+ // CHECK: %[[ALPHA:.+]] = torch.constant.float 1.000000e+00
3708+ // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[A_F32]], %[[B_F32]], %[[ALPHA]] : !torch.vtensor<[1,4096],f32>, !torch.vtensor<[4096],f32>, !torch.float -> !torch.vtensor<[1,4096],f32>
3709+ // CHECK: %[[DTY:.+]] = torch.constant.int 13
3710+ // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[ADD]], %[[CSCALE]], %[[CZP]], %[[DTY]] : !torch.vtensor<[1,4096],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,4096],!torch.quint8>
3711+ // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[1,4096],!torch.quint8> -> !torch.vtensor<[1,4096],ui8>
3712+ // CHECK: return %[[OUT]]
3713+ return %0 : !torch.vtensor <[1 ,4096 ],ui8 >
3714+ }
0 commit comments