Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 86 additions & 1 deletion src/Dialect/ONNX/Transforms/QDQAroundOpOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,39 @@

using namespace mlir;
using namespace onnx_mlir;

/// Check if a value is defined by a constant operation
/// Returns false for NoValue (NoneType)
/// Uses recursive logic to check if all operands are constants (initializers)
static bool isConstantOrInitializer(Value val) {
if (!val)
return false;

// Return false for NoValue (which has NoneType)
if (mlir::isa<NoneType>(val.getType())) {
return false;
}

Operation *definingOp = val.getDefiningOp();
if (!definingOp) {
return false;
}

// Check if it's a constant op
if (llvm::isa<ONNXConstantOp>(definingOp)) {
return true;
}

// Recursively check if all operands are initializers
// If all operands are constants, the result is effectively constant
for (Value operand : definingOp->getOperands()) {
if (!isConstantOrInitializer(operand)) {
return false;
}
}
return true;
}

struct InputAndOutput {
Value input;
Value output;
Expand Down Expand Up @@ -54,12 +87,64 @@ class RemoveQDQAroundOpPattern : public OpRewritePattern<T> {

LogicalResult matchAndRewrite(
T op, PatternRewriter &rewriter) const override {
// Special handling for Resize - only support "nearest" mode
if (llvm::isa<ONNXResizeOp>(op)) {
auto &resizeOp = llvm::cast<ONNXResizeOp>(op);
auto resizeOp = llvm::cast<ONNXResizeOp>(op);
if (resizeOp.getMode() != "nearest") {
return failure();
}

// Resize: require control parameters to be constants
if (!isConstantOrInitializer(resizeOp.getRoi()) ||
!isConstantOrInitializer(resizeOp.getScales()) ||
!isConstantOrInitializer(resizeOp.getSizes())) {
return failure();
}
}

// Unsqueeze requires axes to be a constant
if (llvm::isa<ONNXUnsqueezeOp>(op)) {
auto unsqueezeOp = llvm::cast<ONNXUnsqueezeOp>(op);
if (!isConstantOrInitializer(unsqueezeOp.getAxes())) {
return failure();
}
}

// Squeeze requires axes to be a constant
if (llvm::isa<ONNXSqueezeOp>(op)) {
auto squeezeOp = llvm::cast<ONNXSqueezeOp>(op);
if (!isConstantOrInitializer(squeezeOp.getAxes())) {
return failure();
}
}

// Reshape requires shape to be a constant
if (llvm::isa<ONNXReshapeOp>(op)) {
auto reshapeOp = llvm::cast<ONNXReshapeOp>(op);
if (!isConstantOrInitializer(reshapeOp.getShape())) {
return failure();
}
}

// Gather requires indices to be a constant
if (llvm::isa<ONNXGatherOp>(op)) {
auto gatherOp = llvm::cast<ONNXGatherOp>(op);
if (!isConstantOrInitializer(gatherOp.getIndices())) {
return failure();
}
}

// Slice requires all control parameters to be constants
if (llvm::isa<ONNXSliceOp>(op)) {
auto sliceOp = llvm::cast<ONNXSliceOp>(op);
if (!isConstantOrInitializer(sliceOp.getStarts()) ||
!isConstantOrInitializer(sliceOp.getEnds()) ||
!isConstantOrInitializer(sliceOp.getAxes()) ||
!isConstantOrInitializer(sliceOp.getSteps())) {
return failure();
}
}

InputAndOutput opIO = getDataInputOutput(op);

auto dqOp = opIO.input.getDefiningOp<ONNXDequantizeLinearOp>();
Expand Down
24 changes: 0 additions & 24 deletions test/mlir/onnx/qdq_removal_flatten.mlir

This file was deleted.

20 changes: 0 additions & 20 deletions test/mlir/onnx/qdq_removal_gather.mlir

This file was deleted.

28 changes: 0 additions & 28 deletions test/mlir/onnx/qdq_removal_reshape.mlir

This file was deleted.

78 changes: 0 additions & 78 deletions test/mlir/onnx/qdq_removal_resize.mlir

This file was deleted.

32 changes: 0 additions & 32 deletions test/mlir/onnx/qdq_removal_slice.mlir

This file was deleted.

26 changes: 0 additions & 26 deletions test/mlir/onnx/qdq_removal_squeeze.mlir

This file was deleted.

16 changes: 0 additions & 16 deletions test/mlir/onnx/qdq_removal_transpose.mlir

This file was deleted.

26 changes: 0 additions & 26 deletions test/mlir/onnx/qdq_removal_unsqueeze.mlir

This file was deleted.

Loading