@@ -23,7 +23,7 @@ namespace mlir::water {
2323
2424static bool isSupportedBroadcastType (Type type) {
2525 if (auto integerType = llvm::dyn_cast<IntegerType>(type))
26- return llvm::is_contained ({8 , 16 , 32 , 64 }, (int )integerType.getWidth ());
26+ return llvm::is_contained ({16 , 32 , 64 }, (int )integerType.getWidth ());
2727
2828 if (isa<IndexType, FloatType>(type))
2929 return true ;
@@ -49,28 +49,29 @@ struct InsertBroadcastsPass
4949 // Collect operations that need broadcasts.
5050 SmallVector<Value> insertsNeeded;
5151
52+ auto isUniform = [&](Value value) -> bool {
53+ return water::isUniform (value, solver);
54+ };
55+ auto isNonUniform = [&](Value value) -> bool {
56+ return !water::isUniform (value, solver);
57+ };
58+
5259 op->walk ([&](Operation *currentOp) {
5360 if (isa<gpu::SubgroupBroadcastOp>(currentOp))
5461 return ;
5562
5663 // Check if any operand is non-uniform.
57- bool hasNonUniformOperand = false ;
58- for (Value operand : currentOp->getOperands ()) {
59- if (!water::isUniform (operand, solver)) {
60- hasNonUniformOperand = true ;
61- break ;
62- }
63- }
64-
65- if (!hasNonUniformOperand)
64+ if (!llvm::any_of (currentOp->getOperands (), isNonUniform))
6665 return ;
6766
6867 for (Value result : currentOp->getResults ()) {
69- if (!water::isUniform (result, solver))
68+ if (!isSupportedBroadcastType (result.getType ()))
69+ continue ;
70+
71+ if (!isUniform (result))
7072 continue ;
7173
72- if (isSupportedBroadcastType (result.getType ()))
73- insertsNeeded.push_back (result);
74+ insertsNeeded.push_back (result);
7475 }
7576 });
7677
0 commit comments