@@ -231,9 +231,10 @@ LogicalResult TransOp::verify() {
231231 return success ();
232232}
233233
234- LogicalResult TransOp::inferReturnTypes (
235- MLIRContext *context, std::optional<Location> location,
236- TransOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
234+ LogicalResult
235+ TransOp::inferReturnTypes (MLIRContext *context, std::optional<Location> loc,
236+ TransOp::Adaptor adaptor,
237+ SmallVectorImpl<Type> &inferredReturnTypes) {
237238
238239 // type is the same as the input
239240 auto argTy = cast<RankedTensorType>(adaptor.getSrc ().getType ());
@@ -247,9 +248,8 @@ LogicalResult TransOp::inferReturnTypes(
247248 if (argEncoding) {
248249 Dialect &dialect = argEncoding.getDialect ();
249250 auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
250- if (inferLayoutInterface
251- ->inferTransOpEncoding (argEncoding, shape, order, retEncoding)
252- .failed ()) {
251+ if (failed (inferLayoutInterface->inferTransOpEncoding (
252+ argEncoding, shape, order, retEncoding, loc))) {
253253 return failure ();
254254 }
255255 }
@@ -389,7 +389,8 @@ LogicalResult MakeRangeOp::verify() {
389389
390390// -- ReduceOp --
391391static LogicalResult
392- inferReduceReturnShape (RankedTensorType argTy, Type retEltTy, int axis,
392+ inferReduceReturnShape (std::optional<Location> loc, RankedTensorType argTy,
393+ Type retEltTy, int axis,
393394 SmallVectorImpl<Type> &inferredReturnTypes) {
394395 auto retShape = argTy.getShape ().vec ();
395396 retShape.erase (retShape.begin () + axis);
@@ -404,10 +405,8 @@ inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
404405 if (argEncoding) {
405406 Dialect &dialect = argEncoding.getDialect ();
406407 auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
407- if (inferLayoutInterface
408- ->inferReduceOpEncoding (argEncoding, axis, retEncoding)
409- .failed ()) {
410- llvm::report_fatal_error (" failed to infer layout for ReduceOp" );
408+ if (failed (inferLayoutInterface->inferReduceOpEncoding (
409+ argEncoding, axis, retEncoding, loc))) {
411410 return failure ();
412411 }
413412 }
@@ -418,29 +417,18 @@ inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
418417 return success ();
419418}
420419
421- void ReduceOp::build (OpBuilder &builder, OperationState &state,
422- ValueRange operands, int axis) {
423- SmallVector<Type> inferredReturnTypes;
424- for (unsigned i = 0 ; i < operands.size (); ++i) {
425- auto argTy = cast<RankedTensorType>(operands[i].getType ());
426- auto retEltTy = argTy.getElementType ();
427- (void )inferReduceReturnShape (argTy, retEltTy, axis, inferredReturnTypes);
428- }
429-
430- ReduceOp::build (builder, state, inferredReturnTypes, operands, axis);
431- }
432-
433- LogicalResult ReduceOp::inferReturnTypes (
434- MLIRContext *context, std::optional<Location> location, ValueRange operands,
435- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
436- SmallVectorImpl<Type> &inferredReturnTypes) {
420+ LogicalResult
421+ ReduceOp::inferReturnTypes (MLIRContext *context, std::optional<Location> loc,
422+ ValueRange operands, DictionaryAttr attributes,
423+ OpaqueProperties properties, RegionRange regions,
424+ SmallVectorImpl<Type> &inferredReturnTypes) {
437425 Properties *prop = properties.as <Properties *>();
438426 int axis = prop->axis .getInt ();
439427 for (auto arg : operands) {
440428 auto argTy = cast<RankedTensorType>(arg.getType ());
441429 auto retEltTy = argTy.getElementType ();
442- if (inferReduceReturnShape (argTy, retEltTy, axis, inferredReturnTypes)
443- . failed ( )) {
430+ if (failed ( inferReduceReturnShape (loc, argTy, retEltTy, axis,
431+ inferredReturnTypes) )) {
444432 return failure ();
445433 }
446434 }
@@ -636,9 +624,8 @@ LogicalResult ExpandDimsOp::inferReturnTypes(
636624 if (argEncoding) {
637625 Dialect &dialect = argEncoding.getDialect ();
638626 auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
639- if (inferLayoutInterface
640- ->inferExpandDimsOpEncoding (argEncoding, axis, retEncoding, loc)
641- .failed ())
627+ if (failed (inferLayoutInterface->inferExpandDimsOpEncoding (
628+ argEncoding, axis, retEncoding, loc)))
642629 return emitOptionalError (loc, " failed to infer layout for ExpandDimsOp" );
643630 }
644631 // create type
@@ -674,10 +661,10 @@ LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op,
674661 // Infer the encoding of the new expand op, if encodings are present.
675662 Attribute newExpandEnc;
676663 if (auto srcEnc = srcTy.getEncoding ()) {
677- if (cast<DialectInferLayoutInterface>(& srcEnc.getDialect ())
678- -> inferExpandDimsOpEncoding (srcEnc, op. getAxis (), newExpandEnc,
679- op. getLoc ())
680- . failed ( )) {
664+ Dialect &dialect = srcEnc.getDialect ();
665+ auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
666+ if ( failed (inferLayoutInterface-> inferExpandDimsOpEncoding (
667+ srcEnc, op. getAxis (), newExpandEnc, op. getLoc ()) )) {
681668 return emitOptionalError (op.getLoc (),
682669 " failed to infer layout for ExpandDimsOp" );
683670 }
@@ -719,9 +706,8 @@ OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) {
719706// -- ReshapeOp --
720707
721708void ReshapeOp::build (OpBuilder &builder, OperationState &state,
722- ArrayRef<int64_t > shape,
723- TypedValue<RankedTensorType> src) {
724- auto srcTy = src.getType ();
709+ ArrayRef<int64_t > shape, Value src, bool allowReorder) {
710+ auto srcTy = cast<RankedTensorType>(src.getType ());
725711 auto srcEnc = srcTy.getEncoding ();
726712 Attribute dstEnc;
727713 if (srcEnc) {
@@ -731,7 +717,7 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &state,
731717 assert (succeeded (result));
732718 }
733719 auto dstTy = RankedTensorType::get (shape, srcTy.getElementType (), dstEnc);
734- build (builder, state, dstTy, src);
720+ build (builder, state, dstTy, src, allowReorder );
735721}
736722
737723LogicalResult ReshapeOp::canonicalize (ReshapeOp op, PatternRewriter &rewriter) {
@@ -794,14 +780,14 @@ LogicalResult ReshapeOp::verify() {
794780 // Check that we can infer the dst encoding from the src encoding
795781 // and that the inferred dst encoding is the same as the given dst encoding
796782 Attribute inferredDstEnc;
797- auto result =
798- cast<DialectInferLayoutInterface>(&srcEnc.getDialect ())
799- ->inferReshapeOpEncoding (srcTy. getShape (), srcEnc, dstTy. getShape (),
800- inferredDstEnc, getLoc ());
801- assert ( succeeded (result));
802- return cast<DialectInferLayoutInterface>(&srcEnc. getDialect ())
803- ->verifyLayoutsAreEqual (dstTy. getShape (), inferredDstEnc, dstEnc,
804- getLoc ());
783+ auto layoutInterface =
784+ cast<DialectInferLayoutInterface>(&srcEnc.getDialect ());
785+ auto result = layoutInterface ->inferReshapeOpEncoding (
786+ srcTy. getShape (), srcEnc, dstTy. getShape (), inferredDstEnc, getLoc ());
787+ if ( failed (result))
788+ return failure ();
789+ return layoutInterface ->verifyLayoutsAreEqual (
790+ dstTy. getShape (), inferredDstEnc, dstEnc, getLoc ());
805791}
806792
807793// -- FpToFpOp --
@@ -1092,11 +1078,10 @@ void JoinOp::build(OpBuilder &builder, OperationState &state, Value lhs,
10921078 Attribute srcEnc = lhsTy.getEncoding ();
10931079 Attribute retEnc;
10941080 if (srcEnc) {
1095- if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect ())
1096- ->inferDefaultJoinOpEncoding (srcEnc, retEnc, lhsTy.getShape (),
1097- /* loc=*/ std::nullopt )
1098- .failed ()) {
1099- assert (false && " failed to infer join encoding" );
1081+ if (failed (cast<DialectInferLayoutInterface>(&srcEnc.getDialect ())
1082+ ->inferDefaultJoinOpEncoding (
1083+ srcEnc, retEnc, lhsTy.getShape (), state.location ))) {
1084+ llvm_unreachable (" failed to infer join encoding" );
11001085 }
11011086 }
11021087 auto retTy = RankedTensorType::get (retShape, lhsTy.getElementType (), retEnc);
0 commit comments