
Commit 86f3d20

talumbau authored and copybara-github committed
Add custom shlo lowering for the einsum case used in dot product attention
PiperOrigin-RevId: 718888113
1 parent 6b0713b commit 86f3d20
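
For context, the einsum this commit targets, "abcd,abed->abce", is the batched matmul at the heart of dot product attention: it contracts the shared head dimension (d) of the query and key tensors while batching over the batch (a) and head (b) dimensions. A minimal illustration of the pattern; the names and shapes below are illustrative, not taken from the commit:

import torch

q = torch.randn(2, 8, 16, 64)  # (batch, heads, q_len, head_dim)
k = torch.randn(2, 8, 32, 64)  # (batch, heads, kv_len, head_dim)
scores = torch.einsum("abcd,abed->abce", q, k)  # -> (2, 8, 16, 32)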

1 file changed: 76 additions, 0 deletions
@@ -0,0 +1,76 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Custom op and StableHLO lowering for the 4D batched-matmul einsum used in
# dot product attention.
from ai_edge_torch.odml_torch import lowerings
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
import torch


# Use torch.library.custom_op to define a new custom operator.
@torch.library.custom_op("ai_edge_torch::bmm_4d", mutates_args=())
def bmm_4d(
    lhs: torch.Tensor,
    rhs: torch.Tensor,
) -> torch.Tensor:
  if not (lhs.ndim == 4 and rhs.ndim == 4):
    raise ValueError("bmm_4d requires LHS and RHS have rank 4.")
  d0_can_bcast = (
      lhs.shape[0] == rhs.shape[0] or lhs.shape[0] == 1 or rhs.shape[0] == 1
  )
  d1_can_bcast = (
      lhs.shape[1] == rhs.shape[1] or lhs.shape[1] == 1 or rhs.shape[1] == 1
  )
  if not (d0_can_bcast and d1_can_bcast):
    raise ValueError("bmm_4d requires that dimensions 0 and 1 can broadcast.")

  if lhs.shape[-1] != rhs.shape[-1]:
    raise ValueError(
        "bmm_4d requires LHS and RHS have the same last dimension."
    )

  return torch.einsum("abcd,abed->abce", lhs, rhs)


# Use register_fake to add a ``FakeTensor`` kernel for the operator.
@bmm_4d.register_fake
def _(lhs, rhs):
  return torch.einsum("abcd,abed->abce", lhs, rhs)


@lowerings.lower(torch.ops.ai_edge_torch.bmm_4d)
def _bmm_4d_lower(
    lctx,
    lhs: ir.Value,
    rhs: ir.Value,
):
  # Batch over dimensions 0 and 1 and contract the shared last dimension,
  # matching einsum "abcd,abed->abce".
  dot_dnums = stablehlo.DotDimensionNumbers.get(
      lhs_batching_dimensions=[0, 1],
      rhs_batching_dimensions=[0, 1],
      lhs_contracting_dimensions=(3,),
      rhs_contracting_dimensions=(3,),
  )
  # Result shape is (a, b, c, e): the batch dims, then the non-contracted
  # dims of lhs and rhs.
  return stablehlo.dot_general(
      ir.RankedTensorType.get(
          (
              lhs.type.shape[0],
              lhs.type.shape[1],
              lhs.type.shape[2],
              rhs.type.shape[2],
          ),
          lhs.type.element_type,
      ),
      lhs,
      rhs,
      dot_dnums,
  )
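
Once this module has been imported (so the custom_op and lowering registrations have run), the op is callable through the torch.ops namespace. A minimal usage sketch, assuming the file above is importable in your environment; the shapes are illustrative:

import torch
# Assumption: importing the module from this commit has already registered
# the ai_edge_torch::bmm_4d custom op via torch.library.custom_op.

lhs = torch.randn(1, 8, 16, 64)  # (a, b, c, d)
rhs = torch.randn(1, 8, 32, 64)  # (a, b, e, d)
out = torch.ops.ai_edge_torch.bmm_4d(lhs, rhs)
print(out.shape)  # torch.Size([1, 8, 16, 32]), i.e. (a, b, c, e)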
