Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,

// Reshape of a constant can be replaced with a new constant.
if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
return elements.reshape(
cast<ShapedType>(reshapeOp.getResult().getType()).getShape());

// Fold if the producer reshape source has the same shape with at most 1
// dynamic dimension.
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/IR/BuiltinAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,9 @@ class DenseElementsAttr : public Attribute {
//===--------------------------------------------------------------------===//

/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
DenseElementsAttr reshape(ShapedType newType);
/// attribute, but has been reshaped to 'newShape'. The new shape must have
/// the same total number of elements.
DenseElementsAttr reshape(ArrayRef<int64_t> newShape);

/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but with a different shape for a splat type. The new type must
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/CAPI/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,9 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,

MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
MlirType shapedType) {
return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
.reshape(llvm::cast<ShapedType>(unwrap(shapedType))));
return wrap(
llvm::cast<DenseElementsAttr>(unwrap(attr))
.reshape(llvm::cast<ShapedType>(unwrap(shapedType)).getShape()));
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ struct ConstantCompositeOpPattern final
if (isa<RankedTensorType>(srcType)) {
dstAttrType = RankedTensorType::get(srcType.getNumElements(),
srcType.getElementType());
dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
dstElementsAttr = dstElementsAttr.reshape(dstAttrType.getShape());
} else {
// TODO: add support for large vectors.
return failure();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
return {};

return operand.reshape(
llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
llvm::cast<ShapedType>(operand.getType()).clone(shapeVec).getShape());
}

return {};
Expand Down Expand Up @@ -1546,7 +1546,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
if (input.isSplat() && resultTy.hasStaticShape() &&
input.getType().getElementType() == resultTy.getElementType())
return input.reshape(resultTy);
return input.reshape(resultTy.getShape());
}

// Transpose is not the identity transpose.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
RankedTensorType::get(newShape, oldType.getElementType());

if (input.isSplat()) {
return input.reshape(newType);
return input.reshape(newType.getShape());
}

auto rawData = input.getRawData();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6000,7 +6000,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
return splatAttr.reshape(getType());
return splatAttr.reshape(getType().getShape());

// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
Expand Down Expand Up @@ -6346,7 +6346,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
if (auto splat =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
return splat.reshape(getResultVectorType());
return splat.reshape(getResultVectorType().getShape());

// Eliminate poison transpose ops.
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
loc,
"Cannot linearize a constant scalable vector that's not a splat");

return dstElementsAttr.reshape(resType);
return dstElementsAttr.reshape(resType.getShape());
}

if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1241,8 +1241,10 @@ ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
DenseElementsAttr DenseElementsAttr::reshape(ArrayRef<int64_t> newShape) {

ShapedType curType = getType();
auto newType = curType.cloneWith(newShape, curType.getElementType());

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if result type conversion is intended then we should keep the element type of converted result type and use for subsequent lowering. Clone with newType.getElementType instead of curType.getElementType.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise the converted information like fp8->i8 will lost.
For example,

func.func @canonicalize_extract_shapecast_different_element_type() -> vector<1x192xf8E4M3FN> {
      %0 = arith.constant dense<1.000000e+00> : vector<192xf8E4M3FN>
      %1 = vector.shape_cast %0 : vector<192xf8E4M3FN> to vector<1x192xf8E4M3FN>
      return %1 : vector<1x192xf8E4M3FN>
}

will be converted to as below after -convert-to-llvm -canonicalize="test-convergence"

module {
  llvm.func @canonicalize_extract_shapecast_different_element_type() -> !llvm.array<1 x vector<192xf8E4M3FN>> {
    %0 = llvm.mlir.constant(dense<1.000000e+00> : vector<1x192xf8E4M3FN>) : !llvm.array<1 x vector<192xf8E4M3FN>>
    llvm.return %0 : !llvm.array<1 x vector<192xf8E4M3FN>>
  }
}

However, it should be

module {
  llvm.func @canonicalize_extract_shapecast_different_element_type() -> !llvm.array<1 x vector<192xi8>> {
    %0 = llvm.mlir.constant(dense<1.000000e+00> : vector<1x192xf8E4M3FN>) : !llvm.array<1 x vector<192xi8>>
    llvm.return %0 : !llvm.array<1 x vector<192xi8>>
  }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @MengmSun -- I just tried your example above on this branch with mlir-opt -convert-to-llvm --canonicalize test.mlir and it gives

llvm.func @canonicalize_extract_shapecast_different_element_type() -> !llvm.array<1 x vector<192xi8>> {
  %0 = llvm.mlir.constant(dense<1.000000e+00> : vector<1x192xf8E4M3FN>) : !llvm.array<1 x vector<192xi8>>
  llvm.return %0 : !llvm.array<1 x vector<192xi8>>
}

which I think is what we want

if (curType == newType)
return *this;

Expand Down
Loading