@@ -231,9 +231,10 @@ LogicalResult TransOp::verify() {
   return success();
 }
 
-LogicalResult TransOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location,
-    TransOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+LogicalResult
+TransOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
+                          TransOp::Adaptor adaptor,
+                          SmallVectorImpl<Type> &inferredReturnTypes) {
 
   // type is the same as the input
   auto argTy = cast<RankedTensorType>(adaptor.getSrc().getType());
@@ -247,9 +248,8 @@ LogicalResult TransOp::inferReturnTypes(
   if (argEncoding) {
     Dialect &dialect = argEncoding.getDialect();
     auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
-    if (inferLayoutInterface
-            ->inferTransOpEncoding(argEncoding, shape, order, retEncoding)
-            .failed()) {
+    if (failed(inferLayoutInterface->inferTransOpEncoding(
+            argEncoding, shape, order, retEncoding, loc))) {
       return failure();
     }
   }
@@ -389,7 +389,8 @@ LogicalResult MakeRangeOp::verify() {
 
 // -- ReduceOp --
 static LogicalResult
-inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
+inferReduceReturnShape(std::optional<Location> loc, RankedTensorType argTy,
+                       Type retEltTy, int axis,
                        SmallVectorImpl<Type> &inferredReturnTypes) {
   auto retShape = argTy.getShape().vec();
   retShape.erase(retShape.begin() + axis);
@@ -404,10 +405,8 @@ inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
   if (argEncoding) {
     Dialect &dialect = argEncoding.getDialect();
     auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
-    if (inferLayoutInterface
-            ->inferReduceOpEncoding(argEncoding, axis, retEncoding)
-            .failed()) {
-      llvm::report_fatal_error("failed to infer layout for ReduceOp");
+    if (failed(inferLayoutInterface->inferReduceOpEncoding(
+            argEncoding, axis, retEncoding, loc))) {
       return failure();
     }
   }
@@ -418,29 +417,18 @@ inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
   return success();
 }
 
-void ReduceOp::build(OpBuilder &builder, OperationState &state,
-                     ValueRange operands, int axis) {
-  SmallVector<Type> inferredReturnTypes;
-  for (unsigned i = 0; i < operands.size(); ++i) {
-    auto argTy = cast<RankedTensorType>(operands[i].getType());
-    auto retEltTy = argTy.getElementType();
-    (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes);
-  }
-
-  ReduceOp::build(builder, state, inferredReturnTypes, operands, axis);
-}
-
-LogicalResult ReduceOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
+LogicalResult
+ReduceOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
+                           ValueRange operands, DictionaryAttr attributes,
+                           OpaqueProperties properties, RegionRange regions,
+                           SmallVectorImpl<Type> &inferredReturnTypes) {
   Properties *prop = properties.as<Properties *>();
   int axis = prop->axis.getInt();
   for (auto arg : operands) {
     auto argTy = cast<RankedTensorType>(arg.getType());
     auto retEltTy = argTy.getElementType();
-    if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes)
-            .failed()) {
+    if (failed(inferReduceReturnShape(loc, argTy, retEltTy, axis,
+                                      inferredReturnTypes))) {
       return failure();
     }
   }
@@ -636,9 +624,8 @@ LogicalResult ExpandDimsOp::inferReturnTypes(
   if (argEncoding) {
     Dialect &dialect = argEncoding.getDialect();
     auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
-    if (inferLayoutInterface
-            ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc)
-            .failed())
+    if (failed(inferLayoutInterface->inferExpandDimsOpEncoding(
+            argEncoding, axis, retEncoding, loc)))
       return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp");
   }
   // create type
@@ -674,10 +661,10 @@ LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op,
   // Infer the encoding of the new expand op, if encodings are present.
   Attribute newExpandEnc;
   if (auto srcEnc = srcTy.getEncoding()) {
-    if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
-            ->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc,
-                                        op.getLoc())
-            .failed()) {
+    Dialect &dialect = srcEnc.getDialect();
+    auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
+    if (failed(inferLayoutInterface->inferExpandDimsOpEncoding(
+            srcEnc, op.getAxis(), newExpandEnc, op.getLoc()))) {
       return emitOptionalError(op.getLoc(),
                                "failed to infer layout for ExpandDimsOp");
     }
@@ -719,9 +706,8 @@ OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) {
 // -- ReshapeOp --
 
 void ReshapeOp::build(OpBuilder &builder, OperationState &state,
-                      ArrayRef<int64_t> shape,
-                      TypedValue<RankedTensorType> src) {
-  auto srcTy = src.getType();
+                      ArrayRef<int64_t> shape, Value src, bool allowReorder) {
+  auto srcTy = cast<RankedTensorType>(src.getType());
   auto srcEnc = srcTy.getEncoding();
   Attribute dstEnc;
   if (srcEnc) {
@@ -731,7 +717,7 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &state,
     assert(succeeded(result));
   }
   auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc);
-  build(builder, state, dstTy, src);
+  build(builder, state, dstTy, src, allowReorder);
 }
 
 LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
@@ -794,14 +780,14 @@ LogicalResult ReshapeOp::verify() {
   // Check that we can infer the dst encoding from the src encoding
   // and that the inferred dst encoding is the same as the given dst encoding
   Attribute inferredDstEnc;
-  auto result =
-      cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
-          ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(),
-                                   inferredDstEnc, getLoc());
-  assert(succeeded(result));
-  return cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
-      ->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc,
-                              getLoc());
+  auto layoutInterface =
+      cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
+  auto result = layoutInterface->inferReshapeOpEncoding(
+      srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc, getLoc());
+  if (failed(result))
+    return failure();
+  return layoutInterface->verifyLayoutsAreEqual(
+      dstTy.getShape(), inferredDstEnc, dstEnc, getLoc());
 }
 
 // -- FpToFpOp --
@@ -1092,11 +1078,10 @@ void JoinOp::build(OpBuilder &builder, OperationState &state, Value lhs,
   Attribute srcEnc = lhsTy.getEncoding();
   Attribute retEnc;
   if (srcEnc) {
-    if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
-            ->inferDefaultJoinOpEncoding(srcEnc, retEnc, lhsTy.getShape(),
-                                         /*loc=*/std::nullopt)
-            .failed()) {
-      assert(false && "failed to infer join encoding");
+    if (failed(cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
+                   ->inferDefaultJoinOpEncoding(
+                       srcEnc, retEnc, lhsTy.getShape(), state.location))) {
+      llvm_unreachable("failed to infer join encoding");
     }
   }
   auto retTy = RankedTensorType::get(retShape, lhsTy.getElementType(), retEnc);
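Note on the pattern these hunks share: each call site drops the chained `result.failed()` (and the fatal-error fallbacks) in favor of wrapping the call in `mlir::failed(...)` and passing an optional `Location` into the layout-inference hooks, so failures become diagnostics or plain `failure()` returns. Below is a minimal, self-contained sketch of that idiom, separate from the diff; `inferSomethingEncoding` is a hypothetical stand-in for the `DialectInferLayoutInterface` hooks, not a real API.

// Sketch only: failed()/emitOptionalError-based propagation with an optional
// Location. Header layout assumed from recent MLIR; adjust includes as needed.
#include <optional>

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h" // emitOptionalError
#include "mlir/IR/Location.h"

using namespace mlir;

// Hypothetical helper mirroring the interface hooks touched in this diff.
static LogicalResult inferSomethingEncoding(Attribute srcEnc, Attribute &retEnc,
                                            std::optional<Location> loc) {
  if (!srcEnc)
    // With a null location this just returns failure(); with a real location
    // it also emits a diagnostic.
    return emitOptionalError(loc, "failed to infer layout");
  retEnc = srcEnc;
  return success();
}

static LogicalResult caller(Attribute srcEnc, std::optional<Location> loc) {
  Attribute retEnc;
  // Post-diff style: wrap the call in failed() rather than chaining .failed(),
  // and forward loc instead of asserting or calling report_fatal_error.
  if (failed(inferSomethingEncoding(srcEnc, retEnc, loc)))
    return failure();
  return success();
}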