@@ -75,8 +75,10 @@ absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
7575 return absl::InvalidArgumentError (
7676 " Scatter dimension numbers are not valid for a diagonal tensor." );
7777
78- if (auto iotaOp = dyn_cast<stablehlo::IotaOp>(indices.getDefiningOp ())) {
79- if (iotaOp.getIotaDimension () == 0 ) {
78+ auto isIotaLikeTensor = detectIotaLikeTensor (indices);
79+ if (isIotaLikeTensor) {
80+ auto iotaLikeTensor = isIotaLikeTensor.value ();
81+ if (iotaLikeTensor.dimension == 0 && iotaLikeTensor.start == 0 ) {
8082 *outUpdates = updates;
8183 return absl::OkStatus ();
8284 }
@@ -85,5 +87,174 @@ absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
8587 return absl::InvalidArgumentError (" Not a diagonal tensor." );
8688}
8789
90+ std::optional<IotaLikeTensor> detectIotaLikeTensor (mlir::Value tensor) {
91+ if (!tensor)
92+ return std::nullopt ;
93+
94+ auto elemType =
95+ cast<mlir::RankedTensorType>(tensor.getType ()).getElementType ();
96+ if (!isa<mlir::IntegerType>(elemType))
97+ return std::nullopt ;
98+
99+ struct ChainItem {
100+ mlir::Operation *op;
101+ int64_t offset; // only populated for AddOp/SubtractOp
102+ };
103+
104+ // Build a chain of operations from startOp to the base case
105+ SmallVector<ChainItem> chain;
106+ llvm::DenseSet<mlir::Operation *> visited;
107+ mlir::Operation *currentOp = tensor.getDefiningOp ();
108+
109+ // Traverse to find base case
110+ while (currentOp && !visited.contains (currentOp)) {
111+ visited.insert (currentOp);
112+
113+ // check if we found a base case
114+ if (isa<stablehlo::IotaOp, stablehlo::ConstantOp>(currentOp)) {
115+ chain.push_back ({currentOp, 0 });
116+ break ;
117+ }
118+
119+ // navigate to the next op. If any unsupported intermediate op is found,
120+ // then return std::nullopt
121+ Operation *nextOp;
122+
123+ // TODO: we might want to support broadcast_in_dim / insert_dims / drop_dims
124+ // as well
125+ if (isa<stablehlo::TransposeOp>(currentOp)) {
126+ chain.push_back ({currentOp, 0 });
127+ nextOp = currentOp->getOperand (0 ).getDefiningOp ();
128+ } else if (auto convertOp = dyn_cast<stablehlo::ConvertOp>(currentOp)) {
129+ // if operand of convertOp is not a integer, then return std::nullopt
130+ if (!isa<mlir::IntegerType>(
131+ cast<TensorType>(convertOp.getOperand ().getType ())
132+ .getElementType ()))
133+ return std::nullopt ;
134+ chain.push_back ({currentOp, 0 });
135+ nextOp = convertOp.getOperand ().getDefiningOp ();
136+ } else if (auto addOp = dyn_cast<stablehlo::AddOp>(currentOp)) {
137+ APInt offsetVal;
138+ if (matchPattern (addOp.getRhs (), m_ConstantInt (&offsetVal))) {
139+ chain.push_back ({currentOp, offsetVal.getSExtValue ()});
140+ nextOp = addOp.getLhs ().getDefiningOp ();
141+ } else if (matchPattern (addOp.getLhs (), m_ConstantInt (&offsetVal))) {
142+ chain.push_back ({currentOp, offsetVal.getSExtValue ()});
143+ nextOp = addOp.getRhs ().getDefiningOp ();
144+ } else {
145+ return std::nullopt ;
146+ }
147+ } else if (auto subOp = dyn_cast<stablehlo::SubtractOp>(currentOp)) {
148+ APInt offsetVal;
149+ if (matchPattern (subOp.getRhs (), m_ConstantInt (&offsetVal))) {
150+ chain.push_back ({currentOp, -offsetVal.getSExtValue ()});
151+ nextOp = subOp.getLhs ().getDefiningOp ();
152+ } else {
153+ return std::nullopt ;
154+ }
155+ } else { // unsupported op
156+ return std::nullopt ;
157+ }
158+
159+ currentOp = nextOp;
160+ }
161+
162+ if (chain.empty ())
163+ return std::nullopt ;
164+
165+ // process the base case
166+ IotaLikeTensor result;
167+ if (auto iotaOp = dyn_cast<stablehlo::IotaOp>(chain.back ().op )) {
168+ auto iotaType = cast<RankedTensorType>(iotaOp.getResult ().getType ());
169+ auto iotaDim = static_cast <int64_t >(iotaOp.getIotaDimension ());
170+ result = IotaLikeTensor{0 , iotaType.getShape ()[iotaDim], iotaDim, iotaType};
171+ } else if (auto constantOp =
172+ dyn_cast<stablehlo::ConstantOp>(chain.back ().op )) {
173+ auto denseAttr = cast<DenseElementsAttr>(constantOp.getValue ());
174+ auto constType = cast<RankedTensorType>(constantOp.getResult ().getType ());
175+ auto shape = constType.getShape ();
176+
177+ if (denseAttr.isSplat ())
178+ return std::nullopt ;
179+
180+ // Calculate strides for indexing
181+ SmallVector<int64_t > strides (constType.getRank (), 1 );
182+ for (int64_t i = constType.getRank () - 2 ; i >= 0 ; --i) {
183+ strides[i] = strides[i + 1 ] * shape[i + 1 ];
184+ }
185+
186+ bool isIotaLike = false ;
187+ auto denseAttrValues = denseAttr.getValues <APInt>();
188+
189+ for (int64_t dim = 0 ; dim < constType.getRank (); dim++) {
190+ bool isIotaAlongDim = true ;
191+ std::optional<int64_t > detectedStart;
192+
193+ SmallVector<int64_t > indices (constType.getRank (), 0 );
194+ int64_t numElements = constType.getNumElements ();
195+
196+ for (int64_t idx = 0 ; idx < numElements && isIotaAlongDim; idx++) {
197+ int64_t temp = idx;
198+ // linear to cartesian indexing
199+ for (int64_t d = 0 ; d < constType.getRank (); d++) {
200+ indices[d] = temp / strides[d];
201+ temp = temp % strides[d];
202+ }
203+
204+ int64_t actualValue = denseAttrValues[idx].getSExtValue ();
205+
206+ if (!detectedStart) {
207+ detectedStart = actualValue;
208+ }
209+
210+ int64_t expectedValue = detectedStart.value () + indices[dim];
211+ if (actualValue != expectedValue) {
212+ isIotaAlongDim = false ;
213+ break ;
214+ }
215+ }
216+
217+ if (isIotaAlongDim && detectedStart) {
218+ isIotaLike = true ;
219+ result =
220+ IotaLikeTensor{detectedStart.value (),
221+ detectedStart.value () + shape[dim], dim, constType};
222+ break ;
223+ }
224+ }
225+
226+ if (!isIotaLike)
227+ return std::nullopt ;
228+ } else {
229+ return std::nullopt ;
230+ }
231+
232+ // traverse the chain in reverse order
233+ for (int64_t i = chain.size () - 2 ; i >= 0 ; i--) {
234+ auto item = chain[i];
235+
236+ if (isa<stablehlo::ConvertOp>(item.op )) {
237+ continue ;
238+ } else if (auto transposeOp = dyn_cast<stablehlo::TransposeOp>(item.op )) {
239+ auto permutation = transposeOp.getPermutation ();
240+ for (int64_t idx = 0 ; idx < permutation.size (); idx++) {
241+ if (permutation[idx] == result.dimension ) {
242+ result.dimension = idx;
243+ break ;
244+ }
245+ }
246+ continue ;
247+ } else if (isa<stablehlo::AddOp, stablehlo::SubtractOp>(item.op )) {
248+ result.start += item.offset ;
249+ continue ;
250+ }
251+
252+ assert (false && " reached unreachable case..." );
253+ }
254+
255+ result.tensorType = cast<RankedTensorType>(tensor.getType ());
256+ return result;
257+ }
258+
88259} // namespace enzyme
89260} // namespace mlir
0 commit comments