@@ -1528,6 +1528,105 @@ func.func @contract(%arg0: memref<2x3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: m
15281528
15291529// -----
15301530
1531+ func.func @contract_matmul_bcast_a (%arg0: memref <5 xf32 >, %arg1: memref <5 x7 xf32 >, %arg2: memref <3 x7 xf32 >) {
1532+ linalg.contract index ing_maps = [
1533+ affine_map <(d0 , d1 , d2 ) -> (d2 )>,
1534+ affine_map <(d0 , d1 , d2 ) -> (d2 , d1 )>,
1535+ affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
1536+ ]
1537+ ins (%arg0 , %arg1 : memref <5 xf32 >, memref <5 x7 xf32 >) outs (%arg2: memref <3 x7 xf32 >)
1538+ return
1539+ }
1540+
1541+ // CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
1542+ // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
1543+ // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1544+ // CHECK-LABEL: func @contract_matmul_bcast_a
1545+ // CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1546+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
1547+ // CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
1548+
1549+ // -----
1550+
1551+ func.func @contract_matmul_bcast_b (%arg0: memref <3 x5 xf32 >, %arg1: memref <5 xf32 >, %arg2: memref <3 x7 xf32 >) {
1552+ linalg.contract index ing_maps = [
1553+ affine_map <(d0 , d1 , d2 ) -> (d0 , d2 )>,
1554+ affine_map <(d0 , d1 , d2 ) -> (d2 )>,
1555+ affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
1556+ ]
1557+ ins (%arg0 , %arg1 : memref <3 x5 xf32 >, memref <5 xf32 >) outs (%arg2: memref <3 x7 xf32 >)
1558+ return
1559+ }
1560+
1561+ // CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
1562+ // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
1563+ // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1564+ // CHECK-LABEL: func @contract_matmul_bcast_b
1565+ // CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1566+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
1567+ // CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
1568+
1569+ // -----
1570+
1571+ func.func @contract_matmul_bcast_a_b (%arg0: memref <5 xf32 >, %arg1: memref <5 xf32 >, %arg2: memref <3 x7 xf32 >) {
1572+ linalg.contract index ing_maps = [
1573+ affine_map <(d0 , d1 , d2 ) -> (d2 )>,
1574+ affine_map <(d0 , d1 , d2 ) -> (d2 )>,
1575+ affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
1576+ ]
1577+ ins (%arg0 , %arg1 : memref <5 xf32 >, memref <5 xf32 >) outs (%arg2: memref <3 x7 xf32 >)
1578+ return
1579+ }
1580+
1581+ // CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
1582+ // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1583+ // CHECK-LABEL: func.func @contract_matmul_bcast_a_b
1584+ // CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
1585+ // CHECK: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5xf32>)
1586+ // CHECK: outs(%{{.+}} : memref<3x7xf32>)
1587+
1588+ // -----
1589+
1590+ func.func @contract_matmul_bcast_a_transpose_b (%arg0: memref <5 xf32 >, %arg1: memref <7 x5 xf32 >, %arg2: memref <3 x7 xf32 >) {
1591+ linalg.contract index ing_maps = [
1592+ affine_map <(d0 , d1 , d2 ) -> (d2 )>,
1593+ affine_map <(d0 , d1 , d2 ) -> (d1 , d2 )>,
1594+ affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
1595+ ]
1596+ ins (%arg0 , %arg1 : memref <5 xf32 >, memref <7 x5 xf32 >) outs (%arg2: memref <3 x7 xf32 >)
1597+ return
1598+ }
1599+
1600+ // CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
1601+ // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
1602+ // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1603+ // CHECK-LABEL: func.func @contract_matmul_bcast_a_transpose_b
1604+ // CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1605+ // CHECK: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<7x5xf32>)
1606+ // CHECK: outs(%{{.+}} : memref<3x7xf32>)
1607+
1608+ // -----
1609+
1610+ func.func @contract_matmul_bcast_b_transpose_a (%arg0: memref <5 x3 xf32 >, %arg1: memref <5 xf32 >, %arg2: memref <3 x7 xf32 >) {
1611+ linalg.contract index ing_maps = [
1612+ affine_map <(d0 , d1 , d2 ) -> (d2 , d0 )>,
1613+ affine_map <(d0 , d1 , d2 ) -> (d2 )>,
1614+ affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
1615+ ]
1616+ ins (%arg0 , %arg1 : memref <5 x3 xf32 >, memref <5 xf32 >) outs (%arg2: memref <3 x7 xf32 >)
1617+ return
1618+ }
1619+
1620+ // CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
1621+ // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
1622+ // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1623+ // CHECK-LABEL: func.func @contract_matmul_bcast_b_transpose_a
1624+ // CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1625+ // CHECK: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5xf32>)
1626+ // CHECK: outs(%{{.+}} : memref<3x7xf32>)
1627+
1628+ // -----
1629+
15311630// CHECK-LABEL: func @mmt4d
15321631func.func @mmt4d (%A: tensor <10 x32 x8 x1 xf32 >, %B: tensor <80 x32 x4 x1 xf32 >, %C: tensor <10 x80 x8 x4 xf32 >) -> tensor <10 x80 x8 x4 xf32 > {
15331632 // CHECK: %{{.+}} = linalg.mmt4d
0 commit comments