diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 1660a8734e56..56f62df54a23 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -371,4 +371,13 @@ SmallVector getNontrivialTransitiveUsers(Value v) { return users; } +bool hasVectorOperandsOrResults(Operation& op) { + for (Value value : llvm::concat(op.getOperands(), op.getResults())) { + if (isa(value.getType())) { + return true; + } + } + return false; +} + } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 6def3b6b8739..8512890197e3 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -299,6 +299,8 @@ std::optional getIntConst(Value v); // results. SmallVector getNontrivialTransitiveUsers(Value v); +bool hasVectorOperandsOrResults(Operation& op); + // Return a mod b for a, b > 0, but adjusted to return b when a mod b == 0 such // that the result is strictly positive. template