@@ -469,95 +469,49 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
469
469
return rewriter.notifyMatchFailure (
470
470
binder.op , " Unimplemented: expected 8 input operands" );
471
471
472
- Value a = operands[0 ];
473
- Value aScale = operands[1 ];
474
- Value aZp = operands[2 ];
475
- Value b = operands[3 ];
476
- Value bScale = operands[4 ];
477
- Value bZp = operands[5 ];
478
- Value cScale = operands[6 ];
479
- Value cZp = operands[7 ];
480
-
481
- auto check = [](Value v) {
482
- auto vTy = cast<Torch::ValueTensorType>(v.getType ());
483
- for (auto dim : vTy.getSizes ())
484
- if (dim != 1 )
485
- return false ;
486
- return true ;
487
- };
488
- if (!check (aScale) || !check (aZp) || !check (bScale) || !check (bZp) ||
489
- !check (cScale) || !check (cZp))
490
- return rewriter.notifyMatchFailure (
491
- binder.op , " Unsupported per-tensor quantization" );
492
-
493
- Value emptyList = rewriter.create <Torch::PrimListConstructOp>(
494
- binder.getLoc (),
495
- rewriter.getType <Torch::ListType>(
496
- rewriter.getType <Torch::IntType>()),
497
- ValueRange{});
498
- auto extract = [&rewriter, &binder, &emptyList](Value v) {
499
- auto vTy = cast<Torch::ValueTensorType>(v.getType ());
500
- if (!vTy.getSizes ().empty ()) {
501
- vTy = rewriter.getType <Torch::ValueTensorType>(
502
- ArrayRef<int64_t >({}), vTy.getOptionalDtype ());
503
- v = rewriter.create <Torch::AtenReshapeOp>(binder.getLoc (), vTy, v,
504
- emptyList);
505
- }
506
-
507
- Type extractTy = rewriter.getType <Torch::FloatType>();
508
- if (isa<IntegerType>(vTy.getDtype ()))
509
- extractTy = rewriter.getType <Torch::IntType>();
472
+ Value a, aScale, aZp, b, bScale, bZp, cScale, cZp;
510
473
511
- return rewriter.create <Torch::AtenItemOp>(binder.getLoc (), extractTy,
512
- v);
513
- };
514
-
515
- aZp = extract (aZp);
516
- bZp = extract (bZp);
517
- cZp = extract (cZp);
518
-
519
- aScale = extract (aScale);
520
- bScale = extract (bScale);
521
- cScale = extract (cScale);
522
-
523
- auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
524
- Value zp) -> Value {
525
- auto ty = cast<Torch::ValueTensorType>(v.getType ());
526
- auto newTy = getQTorchTypeFromTorchIntType (ty);
527
- return rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
528
- binder.getLoc (), newTy, v, scale, zp);
529
- };
530
-
531
- a = makePerTensor (a, aScale, aZp);
532
- b = makePerTensor (b, bScale, bZp);
474
+ if (failed (extractPerTensorQuantizationArguments (
475
+ rewriter, loc, /* scale=*/ operands[1 ],
476
+ /* zero_point=*/ operands[2 ], aScale, aZp)))
477
+ return rewriter.notifyMatchFailure (
478
+ binder.op , " Incompatible arguments for per-tensor quantization" );
533
479
534
- auto aTy = dyn_cast<Torch::ValueTensorType>(a.getType ());
535
- if (!aTy || !aTy.hasSizes ())
480
+ if (failed (extractPerTensorQuantizationArguments (
481
+ rewriter, loc, /* scale=*/ operands[4 ],
482
+ /* zero_point=*/ operands[5 ], bScale, bZp)))
536
483
return rewriter.notifyMatchFailure (
537
- binder.op , " Expected input argument `a` to have sizes " );
484
+ binder.op , " Incompatible arguments for per-tensor quantization " );
538
485
539
- aTy = rewriter.getType <Torch::ValueTensorType>(aTy.getOptionalSizes (),
540
- rewriter.getF32Type ());
541
- a = rewriter.create <Torch::AtenDequantizeSelfOp>(binder.getLoc (), aTy,
542
- a);
486
+ if (failed (extractPerTensorQuantizationArguments (
487
+ rewriter, loc, /* scale=*/ operands[6 ],
488
+ /* zero_point=*/ operands[7 ], cScale, cZp)))
489
+ return rewriter.notifyMatchFailure (
490
+ binder.op , " Incompatible arguments for per-tensor quantization" );
543
491
544
- auto bTy = dyn_cast<Torch::ValueTensorType>(b.getType ());
545
- if (!bTy || !bTy.hasSizes ())
492
+ if (failed (createDequantizeTensor (rewriter, loc, /* input=*/ operands[0 ],
493
+ /* scale=*/ aScale, /* zero_point=*/ aZp,
494
+ /* output=*/ a)))
546
495
return rewriter.notifyMatchFailure (
547
- binder.op , " Expected input argument `b` to have sizes" );
496
+ binder.op , " Failed to dequantize the input tensor `a` because of "
497
+ " missing sizes" );
548
498
549
- bTy = rewriter.getType <Torch::ValueTensorType>(bTy.getOptionalSizes (),
550
- rewriter.getF32Type ());
551
- b = rewriter.create <Torch::AtenDequantizeSelfOp>(binder.getLoc (), bTy,
552
- b);
499
+ if (failed (createDequantizeTensor (rewriter, loc, /* input=*/ operands[3 ],
500
+ /* scale=*/ bScale, /* zero_point=*/ bZp,
501
+ /* output=*/ b)))
502
+ return rewriter.notifyMatchFailure (
503
+ binder.op , " Failed to dequantize the input tensor `b` because of "
504
+ " missing sizes" );
553
505
506
+ // Computing the result of "Add".
554
507
auto cTy = rewriter.getType <Torch::ValueTensorType>(
555
508
resultType.getOptionalSizes (), rewriter.getF32Type ());
556
509
Value alpha = rewriter.create <Torch::ConstantFloatOp>(
557
510
loc, rewriter.getF64FloatAttr (1.0 ));
558
511
Value c = rewriter.create <Torch::AtenAddTensorOp>(binder.getLoc (), cTy,
559
512
a, b, alpha);
560
513
514
+ // Quantizing the result of "Add" operation.
561
515
cTy = dyn_cast<Torch::ValueTensorType>(
562
516
getQTorchTypeFromTorchIntType (resultType));
563
517
Value dtyVal = rewriter.create <Torch::ConstantIntOp>(
@@ -588,11 +542,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
588
542
return rewriter.notifyMatchFailure (
589
543
binder.op , " Unimplemented: expected 5 input operands" );
590
544
591
- Value x = operands[0 ];
592
- Value xScale = operands[1 ];
593
- Value xZp = operands[2 ];
594
- Value yScale = operands[3 ];
595
- Value yZp = operands[4 ];
545
+ Value x, xScale, xZp, yScale, yZp;
596
546
597
547
if (failed (extractPerTensorQuantizationArguments (
598
548
rewriter, loc, /* scale=*/ operands[1 ],
@@ -606,18 +556,12 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
606
556
return rewriter.notifyMatchFailure (
607
557
binder.op , " Incompatible arguments for per-tensor quantization" );
608
558
609
- auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType ());
610
- if (!xTy || !xTy.hasSizes ())
559
+ if (failed (createDequantizeTensor (rewriter, loc, /* input=*/ operands[0 ],
560
+ /* scale=*/ xScale, /* zero_point=*/ xZp,
561
+ /* output=*/ x)))
611
562
return rewriter.notifyMatchFailure (
612
- binder.op , " Expected input argument `x` to have sizes" );
613
-
614
- xTy = getQTorchTypeFromTorchIntType (xTy);
615
- x = rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
616
- loc, xTy, x, xScale, xZp);
617
- xTy = rewriter.getType <Torch::ValueTensorType>(xTy.getSizes (),
618
- rewriter.getF32Type ());
619
- // Dequantizing the input tensor `x`.
620
- x = rewriter.create <Torch::AtenDequantizeSelfOp>(loc, xTy, x);
563
+ binder.op , " Failed to dequantize the input tensor `x` because of "
564
+ " missing sizes" );
621
565
622
566
// Computing the LeakyRelu result.
623
567
Value constAlpha = rewriter.create <Torch::ConstantFloatOp>(
@@ -670,16 +614,8 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
670
614
binder.op , " Incompatible number of input operands, scales and/or "
671
615
" zero-points" );
672
616
673
- auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
674
- Value zp) -> Value {
675
- auto ty = cast<Torch::ValueTensorType>(v.getType ());
676
- auto newTy = getQTorchTypeFromTorchIntType (ty);
677
- return rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
678
- binder.getLoc (), newTy, v, scale, zp);
679
- };
680
-
681
- // Preparing the quantized inputs.
682
- SmallVector<Value> quantizedInputs;
617
+ // Preparing the dequantized inputs.
618
+ SmallVector<Value> dequantizedInputs;
683
619
for (unsigned i = 0 ; i < numInputs; i++) {
684
620
Value scale, zeroPoint;
685
621
if (failed (extractPerTensorQuantizationArguments (
@@ -689,24 +625,15 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
689
625
binder.op , " Incompatible scale and zero-points argument for "
690
626
" per-tensor quantization" );
691
627
692
- quantizedInputs.push_back (makePerTensor (inputs[i], scale, zeroPoint));
693
- }
694
-
695
- // Dequantizing the inputs.
696
- SmallVector<Value> dequantizedInputs;
697
- for (unsigned i = 0 ; i < numInputs; i++) {
698
- Torch::ValueTensorType inputTy =
699
- dyn_cast<Torch::ValueTensorType>(quantizedInputs[i].getType ());
700
- if (!inputTy || !inputTy.hasSizes ())
628
+ Value dequantizedInput;
629
+ if (failed (createDequantizeTensor (rewriter, loc, inputs[i], scale,
630
+ zeroPoint,
631
+ /* output=*/ dequantizedInput)))
701
632
return rewriter.notifyMatchFailure (
702
- binder.op , " Expected tensor input operands to be concatenated "
703
- " to have sizes" );
704
-
705
- inputTy = rewriter.getType <Torch::ValueTensorType>(
706
- inputTy.getOptionalSizes (), rewriter.getF32Type ());
707
- dequantizedInputs.push_back (
708
- rewriter.create <Torch::AtenDequantizeSelfOp>(loc, inputTy,
709
- quantizedInputs[i]));
633
+ binder.op , " Failed to dequantize the input tensor because of "
634
+ " missing sizes" );
635
+
636
+ dequantizedInputs.push_back (dequantizedInput);
710
637
}
711
638
712
639
// Concatenating the inputs.
@@ -764,8 +691,19 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
764
691
binder.op ,
765
692
" Unimplemented: support not present for channels_last attribute" );
766
693
767
- Value x = operands[0 ];
768
- Value xScale, xZp, yScale, yZp;
694
+ auto xTy = dyn_cast<Torch::ValueTensorType>(operands[0 ].getType ());
695
+ if (!xTy || !xTy.hasSizes ())
696
+ return rewriter.notifyMatchFailure (
697
+ binder.op , " Expected input argument `x` to have sizes" );
698
+ ArrayRef<int64_t > inputShape = xTy.getSizes ();
699
+
700
+ if (!resultType || !resultType.hasSizes ()) {
701
+ return rewriter.notifyMatchFailure (
702
+ binder.op , " Expected result type having sizes" );
703
+ }
704
+ ArrayRef<int64_t > resultShape = resultType.getSizes ();
705
+
706
+ Value x, xScale, xZp, yScale, yZp;
769
707
770
708
if (failed (extractPerTensorQuantizationArguments (
771
709
rewriter, loc, /* scale=*/ operands[1 ],
@@ -779,25 +717,12 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
779
717
return rewriter.notifyMatchFailure (
780
718
binder.op , " Incompatible arguments for per-tensor quantization" );
781
719
782
- auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType ());
783
- if (!xTy || !xTy.hasSizes ())
784
- return rewriter.notifyMatchFailure (
785
- binder.op , " Expected input argument `x` to have sizes" );
786
- ArrayRef<int64_t > inputShape = xTy.getSizes ();
787
-
788
- xTy = getQTorchTypeFromTorchIntType (xTy);
789
- x = rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
790
- loc, xTy, x, xScale, xZp);
791
- xTy = rewriter.getType <Torch::ValueTensorType>(inputShape,
792
- rewriter.getF32Type ());
793
- // Dequantizing the input tensor `x`.
794
- x = rewriter.create <Torch::AtenDequantizeSelfOp>(loc, xTy, x);
795
-
796
- if (!resultType || !resultType.hasSizes ()) {
720
+ if (failed (createDequantizeTensor (rewriter, loc, /* input=*/ operands[0 ],
721
+ /* scale=*/ xScale, /* zero_point=*/ xZp,
722
+ /* output=*/ x)))
797
723
return rewriter.notifyMatchFailure (
798
- binder.op , " Expected result type having sizes" );
799
- }
800
- ArrayRef<int64_t > resultShape = resultType.getSizes ();
724
+ binder.op , " Failed to dequantize the input tensor `x` because of "
725
+ " missing sizes" );
801
726
802
727
// Computing the AvgPool result.
803
728
SmallVector<Value> cstKernel, cstPadding, cstStrides;
@@ -888,8 +813,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
888
813
return rewriter.notifyMatchFailure (
889
814
binder.op , " Unimplemented: expected 5 input operands" );
890
815
891
- Value x = operands[0 ];
892
- Value xScale, xZp, yScale, yZp;
816
+ Value x, xScale, xZp, yScale, yZp;
893
817
894
818
if (failed (extractPerTensorQuantizationArguments (
895
819
rewriter, loc, /* scale=*/ operands[1 ],
@@ -903,18 +827,12 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
903
827
return rewriter.notifyMatchFailure (
904
828
binder.op , " Incompatible arguments for per-tensor quantization" );
905
829
906
- auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType ());
907
- if (!xTy || !xTy.hasSizes ())
830
+ if (failed (createDequantizeTensor (rewriter, loc, /* input=*/ operands[0 ],
831
+ /* scale=*/ xScale, /* zero_point=*/ xZp,
832
+ /* output=*/ x)))
908
833
return rewriter.notifyMatchFailure (
909
- binder.op , " Expected input argument `x` to have sizes" );
910
-
911
- xTy = getQTorchTypeFromTorchIntType (xTy);
912
- x = rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
913
- loc, xTy, x, xScale, xZp);
914
- xTy = rewriter.getType <Torch::ValueTensorType>(xTy.getSizes (),
915
- rewriter.getF32Type ());
916
- // Dequantizing the input tensor `x`.
917
- x = rewriter.create <Torch::AtenDequantizeSelfOp>(loc, xTy, x);
834
+ binder.op , " Failed to dequantize the input tensor `x` because of "
835
+ " missing sizes" );
918
836
919
837
// Computing the Sigmoid result.
920
838
auto yTy = rewriter.getType <Torch::ValueTensorType>(
@@ -958,8 +876,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
958
876
return rewriter.notifyMatchFailure (
959
877
binder.op , " Unimplemented: expected 5 input operands" );
960
878
961
- Value x = operands[0 ];
962
- Value xScale, xZp, yScale, yZp;
879
+ Value x, xScale, xZp, yScale, yZp;
963
880
964
881
if (failed (extractPerTensorQuantizationArguments (
965
882
rewriter, loc, /* scale=*/ operands[1 ],
@@ -973,18 +890,12 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
973
890
return rewriter.notifyMatchFailure (
974
891
binder.op , " Incompatible arguments for per-tensor quantization" );
975
892
976
- auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType ());
977
- if (!xTy || !xTy.hasSizes ())
893
+ if (failed (createDequantizeTensor (rewriter, loc, /* input=*/ operands[0 ],
894
+ /* scale=*/ xScale, /* zero_point=*/ xZp,
895
+ /* output=*/ x)))
978
896
return rewriter.notifyMatchFailure (
979
- binder.op , " Expected input argument `x` to have sizes" );
980
-
981
- xTy = getQTorchTypeFromTorchIntType (xTy);
982
- x = rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
983
- loc, xTy, x, xScale, xZp);
984
- xTy = rewriter.getType <Torch::ValueTensorType>(xTy.getSizes (),
985
- rewriter.getF32Type ());
986
- // Dequantizing the input tensor `x`.
987
- x = rewriter.create <Torch::AtenDequantizeSelfOp>(loc, xTy, x);
897
+ binder.op , " Failed to dequantize the input tensor `x` because of "
898
+ " missing sizes" );
988
899
989
900
// Creating Onnx.AveragePool op.
990
901
llvm::SmallVector<Value> newOperands = {x};
0 commit comments