extension/llm/modules/test/test_attention.py — 153 changes: 129 additions & 24 deletions
@@ -33,6 +33,7 @@ def setUp(self):
self.num_kv_heads = 8
self.head_dim = 64
self.max_seq_len = 128
self.encoder_max_seq_len = 128
self.rope_base = 500_000
self.scale_factor = 32

@@ -86,16 +87,26 @@ def setUp(self):
max_seq_len=self.max_seq_len,
)
self.et_mha.load_state_dict(self.tt_mha.state_dict())

# Common inputs.
seq_len = 10
self.x = torch.randn(1, seq_len, self.embed_dim)
self.y = torch.randn(1, seq_len, self.embed_dim)
self.input_pos = torch.arange(seq_len).unsqueeze(0) # shape [1, seq_len]
seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
self.dynamic_shapes = (
{0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
{0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
{0: torch.export.Dim.STATIC, 1: seq_len_dim},
)
self.seq_len_dim = torch.export.Dim("seq_len", min=1, max=self.max_seq_len)
self.dynamic_shapes = {
"x": {
0: torch.export.Dim.STATIC,
1: self.seq_len_dim,
2: torch.export.Dim.STATIC,
},
"y": {
0: torch.export.Dim.STATIC,
1: self.seq_len_dim,
2: torch.export.Dim.STATIC,
},
"input_pos": {0: torch.export.Dim.STATIC, 1: self.seq_len_dim},
}
self.causal_mask = torch.tril(
torch.ones(
size=(self.max_seq_len, self.max_seq_len),
@@ -110,8 +121,8 @@ def test_attention_eager(self):
assert_close(et_res, tt_res)

# test with kv cache
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

et_res = self.et_mha(self.x, self.x) # Self attention.
tt_res = self.tt_mha(self.x, self.x) # Self attention.
@@ -144,12 +155,12 @@ def test_attention_export(self):
# Self attention.

# test with kv cache
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.x),
(self.x, self.y),
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
strict=True,
@@ -166,8 +177,8 @@ def test_attention_aoti(self):
# Self attention.

# test with kv cache
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
with torch.no_grad():
so = torch._export.aot_compile(
self.et_mha,
@@ -189,13 +200,13 @@ def test_attention_aoti(self):

def test_attention_executorch(self):
# Self attention.
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.x),
(self.x, self.y),
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
strict=True,
@@ -222,22 +233,18 @@ def test_attention_executorch(self):

def test_attention_torch_cond_eager(self):
# Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond. We need to make sure they are giving the same results regarding the if condition.
# For the first run of MHA we provide `y` (self.x) but for the second run it will be a tensor full of nan.
# For the first run of MHA we provide `y` but for the second run it will be a tensor full of nan.
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

mask = self.causal_mask[self.input_pos, :]
# First run.
et_res = self.et_mha(
self.x, self.x, mask=mask, input_pos=self.input_pos
) # Self attention with input pos.
tt_res = self.tt_mha(
self.x, self.x, mask=mask, input_pos=self.input_pos
) # Self attention with input pos.
et_res = self.et_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)

assert_close(et_res, tt_res)

# Second run test kv cache read. Input pos is [10, 11, ..., 19]
# Second run tests kv cache read. Input pos is [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)

empty_y = torch.full_like(self.x, torch.nan)
@@ -246,3 +253,101 @@ def test_attention_torch_cond_eager(self):
tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

assert_close(et_res, tt_res)

def test_attention_torch_cond_export(self):
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
mask = self.causal_mask[self.input_pos, :]
dynamic_shapes = {
**self.dynamic_shapes,
**{
"mask": {
0: torch.export.Dim.STATIC,
1: self.seq_len_dim,
2: torch.export.Dim.STATIC,
}
},
}
with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.y),
kwargs={
"mask": mask,
"input_pos": self.input_pos,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)

# First run.
et_res = et_mha_ep.module()(self.x, self.y, mask=mask, input_pos=self.input_pos)
tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)

assert_close(et_res, tt_res)

# Second run tests kv cache read. Input pos is [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)
empty_y = torch.full_like(self.y, torch.nan)
mask = self.causal_mask[next_input_pos, :]
et_res = et_mha_ep.module()(
self.x, empty_y, mask=mask, input_pos=next_input_pos
)
tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

assert_close(et_res, tt_res)

def test_attention_torch_cond_executorch(self):
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
mask = self.causal_mask[self.input_pos, :]
dynamic_shapes = {
**self.dynamic_shapes,
**{
"mask": {
0: torch.export.Dim.STATIC,
1: self.seq_len_dim,
2: torch.export.Dim.STATIC,
}
},
}
with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.y),
kwargs={
"mask": mask,
"input_pos": self.input_pos,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)
et_program = to_edge(
et_mha_ep,
compile_config=EdgeCompileConfig(
_core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
_check_ir_validity=False,
),
).to_executorch(
config=ExecutorchBackendConfig(
passes=[InitializedMutableBufferPass(["cache_pos"])],
)
)

# First run.
runtime = Runtime.get()
program = runtime.load_program(et_program.buffer)
method = program.load_method("forward")
et_res = method.execute((self.x, self.y, mask, self.input_pos))
tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)

assert_close(et_res[0], tt_res)

# Second run tests kv cache read. Input pos is [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)
empty_y = torch.full_like(self.y, torch.nan)
mask = self.causal_mask[next_input_pos, :]
et_res = method.execute((self.x, empty_y, mask, next_input_pos))
tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

assert_close(et_res[0], tt_res)
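
Note on the pattern these tests exercise: the comment in test_attention_torch_cond_eager says the ExecuTorch attention module rewrites the data-dependent `if` on the encoder input with `torch.cond`, with a NaN-filled `y` standing in for "no new encoder input". The sketch below is a minimal, hypothetical illustration of that pattern only, not the actual torchtune/ExecuTorch implementation; the names `kv_for_step`, `k_cache`, `v_cache`, `wk`, and `wv` are invented for the example.

import torch

def kv_for_step(y, k_cache, v_cache, wk, wv):
    # Hypothetical sketch: branch on a tensor predicate instead of a Python
    # `if`, so torch.export can trace both paths. A `y` full of NaN means
    # "no new encoder input; reuse the cached projections".
    has_new_y = torch.logical_not(torch.isnan(y).any())

    def recompute(y, k_cache, v_cache, wk, wv):
        # Project the new input. Both torch.cond branches must return
        # tensors with matching shapes and dtypes.
        return y @ wk, y @ wv

    def reuse(y, k_cache, v_cache, wk, wv):
        return k_cache, v_cache

    return torch.cond(has_new_y, recompute, reuse, (y, k_cache, v_cache, wk, wv))

# Tiny usage check (shapes are illustrative):
y = torch.randn(1, 10, 64)
wk, wv = torch.randn(64, 64), torch.randn(64, 64)
k_cache, v_cache = y @ wk, y @ wv  # pretend these were written on the first run
k, v = kv_for_step(torch.full_like(y, float("nan")), k_cache, v_cache, wk, wv)
assert torch.equal(k, k_cache) and torch.equal(v, v_cache)

In eager mode `torch.cond` simply runs the branch selected by the predicate, which is why the tests above can pass a NaN-filled `empty_y` on the second run and still expect the cached keys and values to be read.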