Commit f7a92d3
[Torch Dialect] Decompose AtenTriuOp (#2561)
decompose like:
```
import torch
def my_triu(x, diag):
rows = torch.ops.aten.size(x, -2)
cols = torch.ops.aten.size(x, -1)
row_indices = torch.ops.aten.arange(rows).unsqueeze(1)
col_indices = torch.ops.aten.arange(cols).unsqueeze(0)
cond = torch.ops.aten.ge(
col_indices, torch.ops.aten.add(row_indices, diag))
return torch.ops.aten.where(cond, x, 0)
x = torch.rand(5, 7)
assert torch.allclose(my_triu(x, 0), torch.triu(x, 0))
assert torch.allclose(my_triu(x, 1), torch.triu(x, 1))
assert torch.allclose(my_triu(x, 2), torch.triu(x, 2))
assert torch.allclose(my_triu(x, -1), torch.triu(x, -1))
```
---------
Co-authored-by: LiuYuanqiang <[email protected]>1 parent 49fdc1a commit f7a92d3
File tree
3 files changed
+104
-0
lines changed- lib/Dialect/Torch/Transforms
- projects/pt1/python/torch_mlir_e2e_test/test_suite
3 files changed
+104
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
246 | 246 | | |
247 | 247 | | |
248 | 248 | | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
249 | 305 | | |
250 | 306 | | |
251 | 307 | | |
| |||
5817 | 5873 | | |
5818 | 5874 | | |
5819 | 5875 | | |
| 5876 | + | |
5820 | 5877 | | |
5821 | 5878 | | |
5822 | 5879 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
500 | 500 | | |
501 | 501 | | |
502 | 502 | | |
| 503 | + | |
503 | 504 | | |
504 | 505 | | |
505 | 506 | | |
| |||
Lines changed: 46 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3251 | 3251 | | |
3252 | 3252 | | |
3253 | 3253 | | |
| 3254 | + | |
| 3255 | + | |
| 3256 | + | |
| 3257 | + | |
| 3258 | + | |
| 3259 | + | |
| 3260 | + | |
| 3261 | + | |
| 3262 | + | |
| 3263 | + | |
| 3264 | + | |
| 3265 | + | |
| 3266 | + | |
| 3267 | + | |
| 3268 | + | |
| 3269 | + | |
| 3270 | + | |
| 3271 | + | |
| 3272 | + | |
| 3273 | + | |
| 3274 | + | |
| 3275 | + | |
| 3276 | + | |
| 3277 | + | |
| 3278 | + | |
| 3279 | + | |
| 3280 | + | |
| 3281 | + | |
| 3282 | + | |
| 3283 | + | |
| 3284 | + | |
| 3285 | + | |
| 3286 | + | |
| 3287 | + | |
| 3288 | + | |
| 3289 | + | |
| 3290 | + | |
| 3291 | + | |
| 3292 | + | |
| 3293 | + | |
| 3294 | + | |
| 3295 | + | |
| 3296 | + | |
| 3297 | + | |
| 3298 | + | |
| 3299 | + | |
3254 | 3300 | | |
3255 | 3301 | | |
3256 | 3302 | | |
| |||
0 commit comments