Skip to content

[WIP][mlir] DenseElementsAttr::reshape(arg): make arg a shape, not a type #149947

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,

// Reshape of a constant can be replaced with a new constant.
if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
return elements.reshape(
cast<ShapedType>(reshapeOp.getResult().getType()).getShape());

// Fold if the producer reshape source has the same shape with at most 1
// dynamic dimension.
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/IR/BuiltinAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,9 @@ class DenseElementsAttr : public Attribute {
//===--------------------------------------------------------------------===//

/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
DenseElementsAttr reshape(ShapedType newType);
/// attribute, but has been reshaped to 'newShape'. The new shape must have
/// the same total number of elements.
DenseElementsAttr reshape(ArrayRef<int64_t> newShape);

/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but with a different shape for a splat type. The new type must
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/CAPI/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,9 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,

MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
MlirType shapedType) {
return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
.reshape(llvm::cast<ShapedType>(unwrap(shapedType))));
return wrap(
llvm::cast<DenseElementsAttr>(unwrap(attr))
.reshape(llvm::cast<ShapedType>(unwrap(shapedType)).getShape()));
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ struct ConstantCompositeOpPattern final
if (isa<RankedTensorType>(srcType)) {
dstAttrType = RankedTensorType::get(srcType.getNumElements(),
srcType.getElementType());
dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
dstElementsAttr = dstElementsAttr.reshape(dstAttrType.getShape());
} else {
// TODO: add support for large vectors.
return failure();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
return {};

return operand.reshape(
llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
llvm::cast<ShapedType>(operand.getType()).clone(shapeVec).getShape());
}

return {};
Expand Down Expand Up @@ -1546,7 +1546,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
if (input.isSplat() && resultTy.hasStaticShape() &&
input.getType().getElementType() == resultTy.getElementType())
return input.reshape(resultTy);
return input.reshape(resultTy.getShape());
}

// Transpose is not the identity transpose.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
RankedTensorType::get(newShape, oldType.getElementType());

if (input.isSplat()) {
return input.reshape(newType);
return input.reshape(newType.getShape());
}

auto rawData = input.getRawData();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6000,7 +6000,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
return splatAttr.reshape(getType());
return splatAttr.reshape(getType().getShape());

// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
Expand Down Expand Up @@ -6346,7 +6346,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
if (auto splat =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
return splat.reshape(getResultVectorType());
return splat.reshape(getResultVectorType().getShape());

// Eliminate poison transpose ops.
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
loc,
"Cannot linearize a constant scalable vector that's not a splat");

return dstElementsAttr.reshape(resType);
return dstElementsAttr.reshape(resType.getShape());
}

if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1241,8 +1241,10 @@ ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
DenseElementsAttr DenseElementsAttr::reshape(ArrayRef<int64_t> newShape) {

ShapedType curType = getType();
auto newType = curType.cloneWith(newShape, curType.getElementType());

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if result type conversion is intended then we should keep the element type of converted result type and use for subsequent lowering. Clone with newType.getElementType instead of curType.getElementType.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise the converted information like fp8->i8 will lost.
For example,

func.func @canonicalize_extract_shapecast_different_element_type() -> vector<1x192xf8E4M3FN> {
      %0 = arith.constant dense<1.000000e+00> : vector<192xf8E4M3FN>
      %1 = vector.shape_cast %0 : vector<192xf8E4M3FN> to vector<1x192xf8E4M3FN>
      return %1 : vector<1x192xf8E4M3FN>
}

will be converted to as below after -convert-to-llvm -canonicalize="test-convergence"

module {
  llvm.func @canonicalize_extract_shapecast_different_element_type() -> !llvm.array<1 x vector<192xf8E4M3FN>> {
    %0 = llvm.mlir.constant(dense<1.000000e+00> : vector<1x192xf8E4M3FN>) : !llvm.array<1 x vector<192xf8E4M3FN>>
    llvm.return %0 : !llvm.array<1 x vector<192xf8E4M3FN>>
  }
}

However, it should be

module {
  llvm.func @canonicalize_extract_shapecast_different_element_type() -> !llvm.array<1 x vector<192xi8>> {
    %0 = llvm.mlir.constant(dense<1.000000e+00> : vector<1x192xf8E4M3FN>) : !llvm.array<1 x vector<192xi8>>
    llvm.return %0 : !llvm.array<1 x vector<192xi8>>
  }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @MengmSun -- I just tried your example above on this branch with mlir-opt -convert-to-llvm --canonicalize test.mlir and it gives

llvm.func @canonicalize_extract_shapecast_different_element_type() -> !llvm.array<1 x vector<192xi8>> {
  %0 = llvm.mlir.constant(dense<1.000000e+00> : vector<1x192xf8E4M3FN>) : !llvm.array<1 x vector<192xi8>>
  llvm.return %0 : !llvm.array<1 x vector<192xi8>>
}

which I think is what we want

if (curType == newType)
return *this;

Expand Down
Loading