55
66# RUN: %PYTHON %s | FileCheck %s
77
8- # test_sparse_feature_scaling fails to lower, and we don't need sparsity
9- # look into that later.
10- # UNSUPPORTED: true
11-
128from typing import Any , Callable , Optional , Tuple , Dict
139
1410import torch
@@ -220,25 +216,25 @@ def forward(self, x, v):
220216 print ("torch.mlir =" , res2 )
221217
222218
223- @run
219+ # @run
224220#
225- # CHECK -LABEL: test_sparse_SpMM
226- # CHECK : #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
227- # CHECK : func.func @main(
228- # CHECK -SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
229- # CHECK -SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> {
230- # CHECK : %[[R:.*]] = torch.aten.{{matmul|mm}} %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32>
231- # CHECK : return %[[R]] : !torch.vtensor<[8,8],f32>
232- # CHECK : }
221+ # C_HECK -LABEL: test_sparse_SpMM
222+ # C_HECK : #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
223+ # C_HECK : func.func @main(
224+ # C_HECK -SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
225+ # C_HECK -SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> {
226+ # C_HECK : %[[R:.*]] = torch.aten.{{matmul|mm}} %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32>
227+ # C_HECK : return %[[R]] : !torch.vtensor<[8,8],f32>
228+ # C_HECK : }
233229##
234- # CHECK : torch.sparse
235- # CHECK : tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
236- # CHECK -COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.],
237- # CHECK : [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}})
238- # CHECK : torch.mlir
239- # CHECK : {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.]
240- # CHECK -COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.]
241- # CHECK : [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}}
230+ # C_HECK : torch.sparse
231+ # C_HECK : tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
232+ # C_HECK -COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.],
233+ # C_HECK : [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}})
234+ # C_HECK : torch.mlir
235+ # C_HECK : {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.]
236+ # C_HECK -COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.]
237+ # C_HECK : [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}}
242238#
243239def test_sparse_SpMM ():
244240 class MatMulNet (torch .nn .Module ):
@@ -263,40 +259,40 @@ def forward(self, x, y):
263259 print (res2 )
264260
265261
266- @run
262+ # @run
267263#
268- # CHECK -LABEL: test_sparse_eltwise
269- # CHECK : #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
270- # CHECK : func.func @main(
271- # CHECK -SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> {
272- # CHECK : %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
273- # CHECK : return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
274- # CHECK : }
275- # CHECK : #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
276- # CHECK : func.func @main(
277- # CHECK -SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> {
278- # CHECK : %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
279- # CHECK : return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
280- # CHECK : }
264+ # C_HECK -LABEL: test_sparse_eltwise
265+ # C_HECK : #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
266+ # C_HECK : func.func @main(
267+ # C_HECK -SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> {
268+ # C_HECK : %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
269+ # C_HECK : return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
270+ # C_HECK : }
271+ # C_HECK : #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
272+ # C_HECK : func.func @main(
273+ # C_HECK -SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> {
274+ # C_HECK : %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
275+ # C_HECK : return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
276+ # C_HECK : }
281277#
282- # CHECK : torch.sparse
283- # CHECK : tensor(crow_indices=tensor([0, 2, 4, 6, 8]),
284- # CHECK : col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]),
285- # CHECK : values=tensor({{\[}}[ -1., -2.],
286- # CHECK : [ -3., -4.],
287- # CHECK : [ -5., -6.],
288- # CHECK : [ -7., -8.],
289- # CHECK : [ -9., -10.],
290- # CHECK : [-11., -12.],
291- # CHECK : [-13., -14.],
292- # CHECK : [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8,
293- # CHECK : layout=torch.sparse_csr)
294- # CHECK : torch.mlir
295- # CHECK : [0 2 4 6 8]
296- # CHECK : [0 1 0 1 0 1 0 1]
297- # CHECK : [ -1. -2. -3. -4. -5. -6. -7. -8. -9. -10. -11. -12. -13. -14.
298- # CHECK : -15. -16.]
299- # CHECK : torch.mlir.batch
278+ # C_HECK : torch.sparse
279+ # C_HECK : tensor(crow_indices=tensor([0, 2, 4, 6, 8]),
280+ # C_HECK : col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]),
281+ # C_HECK : values=tensor({{\[}}[ -1., -2.],
282+ # C_HECK : [ -3., -4.],
283+ # C_HECK : [ -5., -6.],
284+ # C_HECK : [ -7., -8.],
285+ # C_HECK : [ -9., -10.],
286+ # C_HECK : [-11., -12.],
287+ # C_HECK : [-13., -14.],
288+ # C_HECK : [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8,
289+ # C_HECK : layout=torch.sparse_csr)
290+ # C_HECK : torch.mlir
291+ # C_HECK : [0 2 4 6 8]
292+ # C_HECK : [0 1 0 1 0 1 0 1]
293+ # C_HECK : [ -1. -2. -3. -4. -5. -6. -7. -8. -9. -10. -11. -12. -13. -14.
294+ # C_HECK : -15. -16.]
295+ # C_HECK : torch.mlir.batch
300296#
301297def test_sparse_eltwise ():
302298 class EltNet (torch .nn .Module ):
@@ -439,20 +435,20 @@ def forward(self, x):
439435 print (res2 [4 ])
440436
441437
442- @run
438+ # @run
443439#
444- # CHECK -LABEL: test_sparse_network
445- # CHECK : func.func @main(
446- # CHECK -SAME: %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> {
440+ # C_HECK -LABEL: test_sparse_network
441+ # C_HECK : func.func @main(
442+ # C_HECK -SAME: %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> {
447443# ... lots of IR ...
448- # CHECK -COUNT-15: torch.aten.mul.Tensor
444+ # C_HECK -COUNT-15: torch.aten.mul.Tensor
449445# ... lots of IR ...
450- # CHECK : }
446+ # C_HECK : }
451447#
452- # CHECK : torch.sparse
453- # CHECK : tensor([ 0., 11., 9., 11., 13., 11., 10., 12.])
454- # CHECK : torch.mlir
455- # CHECK : [ 0. 11. 9. 11. 13. 11. 10. 12.]
448+ # C_HECK : torch.sparse
449+ # C_HECK : tensor([ 0., 11., 9., 11., 13., 11., 10., 12.])
450+ # C_HECK : torch.mlir
451+ # C_HECK : [ 0. 11. 9. 11. 13. 11. 10. 12.]
456452#
457453def test_sparse_network ():
458454 def spike (input ):
@@ -525,30 +521,30 @@ def forward(self, X):
525521 print (res2 )
526522
527523
528- @run
524+ # @run
529525#
530- # CHECK -LABEL: test_sparse_feature_scaling
531- # CHECK : func.func @main(
532- # CHECK -SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
526+ # C_HECK -LABEL: test_sparse_feature_scaling
527+ # C_HECK : func.func @main(
528+ # C_HECK -SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
533529# ... more IR ...
534- # CHECK : %[[D:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"
535- # CHECK : %[[R:.*]] = torch.aten.{{matmul|mm}} %[[D]], %[[A]]
536- # CHECK return %[[R]] : !torch.vtensor<[4,4],f32>
537- # CHECK : }
530+ # C_HECK : %[[D:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"
531+ # C_HECK : %[[R:.*]] = torch.aten.{{matmul|mm}} %[[D]], %[[A]]
532+ # C_HECK return %[[R]] : !torch.vtensor<[4,4],f32>
533+ # C_HECK : }
538534#
539- # CHECK : torch.sparse
540- # CHECK : tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889],
541- # CHECK : [0.1321, 0.2724, 0.2105, 0.3851],
542- # CHECK : [0.2478, 0.3439, 0.1898, 0.2185],
543- # CHECK : [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
535+ # C_HECK : torch.sparse
536+ # C_HECK : tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889],
537+ # C_HECK : [0.1321, 0.2724, 0.2105, 0.3851],
538+ # C_HECK : [0.2478, 0.3439, 0.1898, 0.2185],
539+ # C_HECK : [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
544540#
545541# TODO: first row looks suspect...
546542#
547- # CHECK : torch.mlir
548- # CHECK : {{\[}}[0. 0. 0. 0. ]
549- # CHECK : [0.13205223 0.27236593 0.21051763 0.38506418]
550- # CHECK : [0.24781987 0.34391665 0.18976606 0.2184974 ]
551- # CHECK : [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}}
543+ # C_HECK : torch.mlir
544+ # C_HECK : {{\[}}[0. 0. 0. 0. ]
545+ # C_HECK : [0.13205223 0.27236593 0.21051763 0.38506418]
546+ # C_HECK : [0.24781987 0.34391665 0.18976606 0.2184974 ]
547+ # C_HECK : [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}}
552548#
553549def test_sparse_feature_scaling ():
554550 class Scale (nn .Module ):
0 commit comments