Skip to content

Commit 64d8481

Browse files
GleasonKhanrach9
andauthored
Add remaining unary result accuracies (#2749)
Remaining implementation of https://github.com/openxla/stablehlo/blob/main/rfcs/20241015-result-accuracy.md Co-authored-by: Rachel Han <[email protected]>
1 parent 80929da commit 64d8481

23 files changed

+4049
-173
lines changed

stablehlo/dialect/StablehloOps.cpp

Lines changed: 150 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ limitations under the License.
8787
#include "stablehlo/dialect/AssemblyFormat.h"
8888
#include "stablehlo/dialect/Base.h"
8989
#include "stablehlo/dialect/StablehloBytecode.h"
90-
#include "stablehlo/dialect/StablehloOps.h"
9190
#include "stablehlo/dialect/StablehloOps.h.inc"
9291
#include "stablehlo/dialect/TypeInference.h"
9392

@@ -793,10 +792,6 @@ LogicalResult DotAlgorithmAttr::verify(
793792
allowImpreciseAccumulation);
794793
}
795794

796-
// ===----------------------------------------------------------------------===//
797-
// ExpOp
798-
//===----------------------------------------------------------------------===//
799-
800795
LogicalResult ResultAccuracyAttr::verify(
801796
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol,
802797
APFloat rtol, int64_t ulps, ResultAccuracyModeAttr mode) {
@@ -805,13 +800,158 @@ LogicalResult ResultAccuracyAttr::verify(
805800
stringifyResultAccuracyMode(mode.getValue()));
806801
}
807802

803+
// ===---------------------------------------------------------------------===//
804+
// CbrtOp
805+
//===----------------------------------------------------------------------===//
806+
807+
LogicalResult CbrtOp::verify() {
808+
if (auto attr = getResultAccuracyAttr()) {
809+
return ResultAccuracyAttr::verify([&] { return emitError(); },
810+
attr.getAtol(), attr.getRtol(),
811+
attr.getUlps(), attr.getMode());
812+
}
813+
return success();
814+
}
815+
816+
// ===---------------------------------------------------------------------===//
817+
// CosineOp
818+
//===----------------------------------------------------------------------===//
819+
820+
LogicalResult CosineOp::verify() {
821+
if (auto attr = getResultAccuracyAttr()) {
822+
return ResultAccuracyAttr::verify([&] { return emitError(); },
823+
attr.getAtol(), attr.getRtol(),
824+
attr.getUlps(), attr.getMode());
825+
}
826+
return success();
827+
}
828+
829+
// ===---------------------------------------------------------------------===//
830+
// ExpOp
831+
//===----------------------------------------------------------------------===//
832+
808833
LogicalResult ExpOp::verify() {
809834
if (auto attr = getResultAccuracyAttr()) {
810-
if (failed(ResultAccuracyAttr::verify([&] { return emitError(); },
811-
attr.getAtol(), attr.getRtol(),
812-
attr.getUlps(), attr.getMode()))) {
813-
return failure();
814-
}
835+
return ResultAccuracyAttr::verify([&] { return emitError(); },
836+
attr.getAtol(), attr.getRtol(),
837+
attr.getUlps(), attr.getMode());
838+
}
839+
return success();
840+
}
841+
842+
// ===---------------------------------------------------------------------===//
843+
// Expm1Op
844+
//===----------------------------------------------------------------------===//
845+
846+
LogicalResult Expm1Op::verify() {
847+
if (auto attr = getResultAccuracyAttr()) {
848+
return ResultAccuracyAttr::verify([&] { return emitError(); },
849+
attr.getAtol(), attr.getRtol(),
850+
attr.getUlps(), attr.getMode());
851+
}
852+
return success();
853+
}
854+
855+
// ===---------------------------------------------------------------------===//
856+
// LogOp
857+
//===----------------------------------------------------------------------===//
858+
859+
LogicalResult LogOp::verify() {
860+
if (auto attr = getResultAccuracyAttr()) {
861+
return ResultAccuracyAttr::verify([&] { return emitError(); },
862+
attr.getAtol(), attr.getRtol(),
863+
attr.getUlps(), attr.getMode());
864+
}
865+
return success();
866+
}
867+
868+
// ===---------------------------------------------------------------------===//
869+
// Log1pOp
870+
//===----------------------------------------------------------------------===//
871+
872+
LogicalResult Log1pOp::verify() {
873+
if (auto attr = getResultAccuracyAttr()) {
874+
return ResultAccuracyAttr::verify([&] { return emitError(); },
875+
attr.getAtol(), attr.getRtol(),
876+
attr.getUlps(), attr.getMode());
877+
}
878+
return success();
879+
}
880+
881+
// ===---------------------------------------------------------------------===//
882+
// LogisticOp
883+
//===----------------------------------------------------------------------===//
884+
885+
LogicalResult LogisticOp::verify() {
886+
if (auto attr = getResultAccuracyAttr()) {
887+
return ResultAccuracyAttr::verify([&] { return emitError(); },
888+
attr.getAtol(), attr.getRtol(),
889+
attr.getUlps(), attr.getMode());
890+
}
891+
return success();
892+
}
893+
894+
// ===---------------------------------------------------------------------===//
895+
// RsqrtOp
896+
//===----------------------------------------------------------------------===//
897+
898+
LogicalResult RsqrtOp::verify() {
899+
if (auto attr = getResultAccuracyAttr()) {
900+
return ResultAccuracyAttr::verify([&] { return emitError(); },
901+
attr.getAtol(), attr.getRtol(),
902+
attr.getUlps(), attr.getMode());
903+
}
904+
return success();
905+
}
906+
907+
// ===---------------------------------------------------------------------===//
908+
// SinOp
909+
//===----------------------------------------------------------------------===//
910+
911+
LogicalResult SineOp::verify() {
912+
if (auto attr = getResultAccuracyAttr()) {
913+
return ResultAccuracyAttr::verify([&] { return emitError(); },
914+
attr.getAtol(), attr.getRtol(),
915+
attr.getUlps(), attr.getMode());
916+
}
917+
return success();
918+
}
919+
920+
// ===---------------------------------------------------------------------===//
921+
// SqrtOp
922+
//===----------------------------------------------------------------------===//
923+
924+
LogicalResult SqrtOp::verify() {
925+
if (auto attr = getResultAccuracyAttr()) {
926+
return ResultAccuracyAttr::verify([&] { return emitError(); },
927+
attr.getAtol(), attr.getRtol(),
928+
attr.getUlps(), attr.getMode());
929+
}
930+
return success();
931+
}
932+
933+
// ===---------------------------------------------------------------------===//
934+
// TanOp
935+
//===----------------------------------------------------------------------===//
936+
937+
LogicalResult TanOp::verify() {
938+
if (auto attr = getResultAccuracyAttr()) {
939+
return ResultAccuracyAttr::verify([&] { return emitError(); },
940+
attr.getAtol(), attr.getRtol(),
941+
attr.getUlps(), attr.getMode());
942+
}
943+
return success();
944+
}
945+
946+
// ===---------------------------------------------------------------------===//
947+
// TanhOp
948+
//===----------------------------------------------------------------------===//
949+
950+
LogicalResult TanhOp::verify() {
951+
if (auto attr = getResultAccuracyAttr()) {
952+
return ResultAccuracyAttr::verify([&] { return emitError(); },
953+
attr.getAtol(), attr.getRtol(),
954+
attr.getUlps(), attr.getMode());
815955
}
816956
return success();
817957
}

stablehlo/dialect/StablehloOps.td

Lines changed: 102 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,16 @@ def StableHLO_CbrtOp: StableHLO_UnaryElementwiseOp<"cbrt",
240240
%result = stablehlo.cbrt %operand : tensor<4xf64>
241241
```
242242
}];
243+
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
244+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
245+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
246+
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
247+
let hasVerifier = 1;
248+
249+
let assemblyFormat = [{
250+
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
251+
}];
252+
243253
}
244254

245255
def StableHLO_CeilOp: StableHLO_UnaryElementwiseOp<"ceil",
@@ -310,6 +320,15 @@ def StableHLO_CosineOp: StableHLO_UnaryElementwiseOp<"cosine",
310320
%result = stablehlo.cosine %operand : tensor<2xf32>
311321
```
312322
}];
323+
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
324+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
325+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
326+
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
327+
let hasVerifier = 1;
328+
329+
let assemblyFormat = [{
330+
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
331+
}];
313332
}
314333

315334
def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential",
@@ -329,17 +348,9 @@ def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential",
329348
```
330349
}];
331350
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
332-
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr, "::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
351+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
352+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
333353
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
334-
let extraClassDeclaration = commonClassDeclaration # [{
335-
LogicalResult reifyReturnTypeShapes(
336-
OpBuilder& builder, ValueRange operands,
337-
SmallVectorImpl<Value>& reifiedReturnShapes) {
338-
return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(),
339-
operands.front(),
340-
&reifiedReturnShapes);
341-
}
342-
}];
343354
let hasVerifier = 1;
344355

345356
let assemblyFormat = [{
@@ -363,6 +374,15 @@ def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one",
363374
%result = stablehlo.exponential_minus_one %operand : tensor<2xf64>
364375
```
365376
}];
377+
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
378+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
379+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
380+
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
381+
let hasVerifier = 1;
382+
383+
let assemblyFormat = [{
384+
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
385+
}];
366386
}
367387

368388
def StableHLO_FloorOp: StableHLO_UnaryElementwiseOp<"floor",
@@ -440,6 +460,15 @@ def StableHLO_LogOp: StableHLO_UnaryElementwiseOp<"log",
440460
%result = stablehlo.log %operand : tensor<2x2xf64>
441461
```
442462
}];
463+
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
464+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
465+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
466+
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
467+
let hasVerifier = 1;
468+
469+
let assemblyFormat = [{
470+
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
471+
}];
443472
}
444473

445474
def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one",
@@ -458,6 +487,15 @@ def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one",
458487
%result = stablehlo.log_plus_one %operand : tensor<5xf64>
459488
```
460489
}];
490+
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
491+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
492+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
493+
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
494+
let hasVerifier = 1;
495+
496+
let assemblyFormat = [{
497+
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
498+
}];
461499
}
462500

463501
def StableHLO_LogisticOp: StableHLO_UnaryElementwiseOp<"logistic",
@@ -476,6 +514,15 @@ def StableHLO_LogisticOp: StableHLO_UnaryElementwiseOp<"logistic",
476514
%result = stablehlo.logistic %operand : tensor<2x2xf64>
477515
```
478516
}];
517+
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
518+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
519+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
520+
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
521+
let hasVerifier = 1;
522+
523+
let assemblyFormat = [{
524+
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
525+
}];
479526
}
480527

481528
def StableHLO_NotOp: StableHLO_UnaryElementwiseOp<"not",
@@ -602,6 +649,15 @@ def StableHLO_RsqrtOp: StableHLO_UnaryElementwiseOp<"rsqrt",
602649
%result = stablehlo.rsqrt %operand : tensor<2x2xf32>
603650
```
604651
}];
652+
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
653+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
654+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
655+
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
656+
let hasVerifier = 1;
657+
658+
let assemblyFormat = [{
659+
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
660+
}];
605661
}
606662

607663
def StableHLO_SignOp: StableHLO_UnaryElementwiseOp<"sign",
@@ -637,6 +693,15 @@ def StableHLO_SineOp: StableHLO_UnaryElementwiseOp<"sine",
637693
%result = stablehlo.sine %operand : tensor<2xf32>
638694
```
639695
}];
696+
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
697+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
698+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
699+
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
700+
let hasVerifier = 1;
701+
702+
let assemblyFormat = [{
703+
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
704+
}];
640705
}
641706

642707
def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt",
@@ -655,6 +720,15 @@ def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt",
655720
%result = stablehlo.sqrt %operand : tensor<2x2xf32>
656721
```
657722
}];
723+
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
724+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
725+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
726+
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
727+
let hasVerifier = 1;
728+
729+
let assemblyFormat = [{
730+
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
731+
}];
658732
}
659733

660734
def StableHLO_TanOp: StableHLO_UnaryElementwiseOp<"tan",
@@ -673,6 +747,15 @@ def StableHLO_TanOp: StableHLO_UnaryElementwiseOp<"tan",
673747
%result = stablehlo.tan %operand : tensor<2x2xf64>
674748
```
675749
}];
750+
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
751+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
752+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
753+
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
754+
let hasVerifier = 1;
755+
756+
let assemblyFormat = [{
757+
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
758+
}];
676759
}
677760

678761
def StableHLO_TanhOp: StableHLO_UnaryElementwiseOp<"tanh",
@@ -691,6 +774,15 @@ def StableHLO_TanhOp: StableHLO_UnaryElementwiseOp<"tanh",
691774
%result = stablehlo.tanh %operand : tensor<2xf32>
692775
```
693776
}];
777+
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
778+
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr,
779+
"::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
780+
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
781+
let hasVerifier = 1;
782+
783+
let assemblyFormat = [{
784+
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
785+
}];
694786
}
695787

696788
//===----------------------------------------------------------------------===//

stablehlo/dialect/Version.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class Version {
3838
static FailureOr<Version> fromString(llvm::StringRef versionRef);
3939

4040
/// Return a Version representing the current VHLO dialect version.
41-
static Version getCurrentVersion() { return Version(1, 9, 8); }
41+
static Version getCurrentVersion() { return Version(1, 10, 0); }
4242

4343
/// Return a Version representing the minimum supported VHLO dialect version.
4444
static Version getMinimumVersion() { return Version(0, 9, 0); }

stablehlo/dialect/VhloDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def VHLO_Dialect : Dialect {
4848
1.7.0: Introduce `f8E4M3` and `f8E3M4` types.
4949
1.8.0: Introduce `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` types.
5050
1.9.0: Add `ResultAccuracy` attribute to `exp` op.
51+
1.10.0: Add `ResultAccuracy` attribute to `cbrt`, `cosine`, `exponential`, `exponential_minus_one`, `log`, `log_plus_one`, `logistic`, `rsqrt`, `sine`, `sqrt`, `tan` and `tanh` ops.
5152
}];
5253

5354
let useDefaultAttributePrinterParser = 0;

0 commit comments

Comments
 (0)