@@ -2104,20 +2104,29 @@ LogicalResult spirv::GLDistanceOp::verify() {
21042104 auto p1Type = getP1 ().getType ();
21052105 auto resultType = getResult ().getType ();
21062106
2107- auto p0VectorType = p0Type.dyn_cast <VectorType>();
2108- auto p1VectorType = p1Type.dyn_cast <VectorType>();
2109- if (!p0VectorType || !p1VectorType)
2110- return emitOpError (" operands must be vectors" );
2107+ auto getFloatType = [](Type type) -> FloatType {
2108+ if (auto vectorType = llvm::dyn_cast<VectorType>(type))
2109+ return llvm::dyn_cast<FloatType>(vectorType.getElementType ());
2110+ return llvm::dyn_cast<FloatType>(type);
2111+ };
21112112
2112- if (p0VectorType.getShape () != p1VectorType.getShape ())
2113- return emitOpError (" operands must have same shape" );
2113+ FloatType p0FloatType = getFloatType (p0Type);
2114+ FloatType p1FloatType = getFloatType (p1Type);
2115+ FloatType resultFloatType = llvm::dyn_cast<FloatType>(resultType);
21142116
2115- if (!resultType. isa <FloatType>() )
2116- return emitOpError (" result must be scalar float" );
2117+ if (!p0FloatType || !p1FloatType || !resultFloatType )
2118+ return emitOpError (" operands and result must be float scalar or vector of float" );
21172119
2118- if (p0VectorType.getElementType () != resultType ||
2119- p1VectorType.getElementType () != resultType)
2120- return emitOpError (" operand vector elements must match result type" );
2120+ if (p0FloatType != resultFloatType || p1FloatType != resultFloatType)
2121+ return emitOpError (" operand and result element types must match" );
2122+
2123+ if (auto p0Vec = llvm::dyn_cast<VectorType>(p0Type)) {
2124+ if (!llvm::dyn_cast<VectorType>(p1Type) ||
2125+ p0Vec.getShape () != llvm::dyn_cast<VectorType>(p1Type).getShape ())
2126+ return emitOpError (" vector operands must have same shape" );
2127+ } else if (llvm::isa<VectorType>(p1Type)) {
2128+ return emitOpError (" expected both operands to be scalars or both to be vectors" );
2129+ }
21212130
21222131 return success ();
21232132}
0 commit comments