Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,13 @@ def SPIRV_GLFMixOp :

// -----

def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [Pure]> {
def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [
Pure,
AllElementTypesMatch<["p0", "p1"]>,
TypesMatchWith<"result type must match operand element type",
"p0", "result",
"::mlir::getElementTypeOrSelf($_self)">
]> {
let summary = "Return distance between two points";

let description = [{
Expand Down Expand Up @@ -1060,6 +1066,8 @@ def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [Pure]> {
let assemblyFormat = [{
operands attr-dict `:` type($p0) `,` type($p1) `->` type($result)
}];

let hasVerifier = 0;
}

// -----
Expand Down
38 changes: 0 additions & 38 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2094,41 +2094,3 @@ LogicalResult spirv::VectorTimesScalarOp::verify() {
return emitOpError("scalar operand and result element type match");
return success();
}

//===----------------------------------------------------------------------===//
// spirv.GLDistanceOp
//===----------------------------------------------------------------------===//

LogicalResult spirv::GLDistanceOp::verify() {
auto p0Type = getP0().getType();
auto p1Type = getP1().getType();
auto resultType = getResult().getType();

auto getFloatType = [](Type type) -> FloatType {
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return llvm::dyn_cast<FloatType>(vectorType.getElementType());
return llvm::dyn_cast<FloatType>(type);
};

FloatType p0FloatType = getFloatType(p0Type);
FloatType p1FloatType = getFloatType(p1Type);
FloatType resultFloatType = llvm::dyn_cast<FloatType>(resultType);

if (!p0FloatType || !p1FloatType || !resultFloatType)
return emitOpError(
"operands and result must be float scalar or vector of float");

if (p0FloatType != resultFloatType || p1FloatType != resultFloatType)
return emitOpError("operand and result element types must match");

if (auto p0Vec = llvm::dyn_cast<VectorType>(p0Type)) {
if (!llvm::dyn_cast<VectorType>(p1Type) ||
p0Vec.getShape() != llvm::dyn_cast<VectorType>(p1Type).getShape())
return emitOpError("vector operands must have same shape");
} else if (llvm::isa<VectorType>(p1Type)) {
return emitOpError(
"expected both operands to be scalars or both to be vectors");
}

return success();
}
Loading