diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 7afd6e9b25b77..1d3436c99a63c 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1061,12 +1061,30 @@ LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR( << operands[2]; } - unsigned rows = getConstantInt(operands[3]).getInt(); - unsigned columns = getConstantInt(operands[4]).getInt(); + IntegerAttr rowsAttr = getConstantInt(operands[3]); + IntegerAttr columnsAttr = getConstantInt(operands[4]); + IntegerAttr useAttr = getConstantInt(operands[5]); + + if (!rowsAttr) + return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references " + "undefined constant ") + << operands[3]; + + if (!columnsAttr) + return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` " + "references undefined constant ") + << operands[4]; + + if (!useAttr) + return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references " + "undefined constant ") + << operands[5]; + + unsigned rows = rowsAttr.getInt(); + unsigned columns = columnsAttr.getInt(); std::optional use = - spirv::symbolizeCooperativeMatrixUseKHR( - getConstantInt(operands[5]).getInt()); + spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt()); if (!use) { return emitError( unknownLoc,