diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py
index 6cd05b4bf65..3ecf0b2b4ba 100644
--- a/extension/llm/modules/test/test_attention.py
+++ b/extension/llm/modules/test/test_attention.py
@@ -33,6 +33,7 @@ def setUp(self):
         self.num_kv_heads = 8
         self.head_dim = 64
         self.max_seq_len = 128
+        self.encoder_max_seq_len = 128
         self.rope_base = 500_000
         self.scale_factor = 32
 
@@ -86,16 +87,26 @@ def setUp(self):
             max_seq_len=self.max_seq_len,
         )
         self.et_mha.load_state_dict(self.tt_mha.state_dict())
 
+        # Common inputs.
         seq_len = 10
         self.x = torch.randn(1, seq_len, self.embed_dim)
+        self.y = torch.randn(1, seq_len, self.embed_dim)
         self.input_pos = torch.arange(seq_len).unsqueeze(0)  # shape [1, seq_len]
-        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
-        self.dynamic_shapes = (
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim},
-        )
+        self.seq_len_dim = torch.export.Dim("seq_len", min=1, max=self.max_seq_len)
+        self.dynamic_shapes = {
+            "x": {
+                0: torch.export.Dim.STATIC,
+                1: self.seq_len_dim,
+                2: torch.export.Dim.STATIC,
+            },
+            "y": {
+                0: torch.export.Dim.STATIC,
+                1: self.seq_len_dim,
+                2: torch.export.Dim.STATIC,
+            },
+            "input_pos": {0: torch.export.Dim.STATIC, 1: self.seq_len_dim},
+        }
         self.causal_mask = torch.tril(
             torch.ones(
                 size=(self.max_seq_len, self.max_seq_len),
@@ -110,8 +121,8 @@ def test_attention_eager(self):
         assert_close(et_res, tt_res)
 
         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
 
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.
@@ -144,12 +155,12 @@ def test_attention_export(self):
         # Self attention.
 
         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         with torch.no_grad():
             et_mha_ep = torch.export.export(
                 self.et_mha,
-                (self.x, self.x),
+                (self.x, self.y),
                 kwargs={"input_pos": self.input_pos},
                 dynamic_shapes=self.dynamic_shapes,
                 strict=True,
@@ -166,8 +177,8 @@ def test_attention_aoti(self):
         # Self attention.
 
         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         with torch.no_grad():
             so = torch._export.aot_compile(
                 self.et_mha,
@@ -189,13 +200,13 @@ def test_attention_aoti(self):
 
     def test_attention_executorch(self):
         # Self attention.
 
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         with torch.no_grad():
             et_mha_ep = torch.export.export(
                 self.et_mha,
-                (self.x, self.x),
+                (self.x, self.y),
                 kwargs={"input_pos": self.input_pos},
                 dynamic_shapes=self.dynamic_shapes,
                 strict=True,
@@ -222,22 +233,18 @@ def test_attention_executorch(self):
 
     def test_attention_torch_cond_eager(self):
         # Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond. We need to make sure they are giving the same results regarding the if condition.
-        # For the first run of MHA we provide `y` (self.x) but for the second run it will be a tensor full of nan.
+        # For the first run of MHA we provide `y` but for the second run it will be a tensor full of nan.
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
 
         mask = self.causal_mask[self.input_pos, :]
         # First run.
-        et_res = self.et_mha(
-            self.x, self.x, mask=mask, input_pos=self.input_pos
-        )  # Self attention with input pos.
-        tt_res = self.tt_mha(
-            self.x, self.x, mask=mask, input_pos=self.input_pos
-        )  # Self attention with input pos.
+        et_res = self.et_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
 
         assert_close(et_res, tt_res)
 
-        # Second run test kv cache read. Input pos is [10, 11, ..., 19]
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
         next_input_pos = torch.arange(10, 20).unsqueeze(0)
 
         empty_y = torch.full_like(self.x, torch.nan)
@@ -246,3 +253,101 @@ def test_attention_torch_cond_eager(self):
         tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
 
         assert_close(et_res, tt_res)
+
+    def test_attention_torch_cond_export(self):
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        mask = self.causal_mask[self.input_pos, :]
+        dynamic_shapes = {
+            **self.dynamic_shapes,
+            **{
+                "mask": {
+                    0: torch.export.Dim.STATIC,
+                    1: self.seq_len_dim,
+                    2: torch.export.Dim.STATIC,
+                }
+            },
+        }
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.y),
+                kwargs={
+                    "mask": mask,
+                    "input_pos": self.input_pos,
+                },
+                dynamic_shapes=dynamic_shapes,
+                strict=True,
+            )
+
+        # First run.
+        et_res = et_mha_ep.module()(self.x, self.y, mask=mask, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+
+        assert_close(et_res, tt_res)
+
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+        empty_y = torch.full_like(self.y, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = et_mha_ep.module()(
+            self.x, empty_y, mask=mask, input_pos=next_input_pos
+        )
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
+
+        assert_close(et_res, tt_res)
+
+    def test_attention_torch_cond_executorch(self):
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        mask = self.causal_mask[self.input_pos, :]
+        dynamic_shapes = {
+            **self.dynamic_shapes,
+            **{
+                "mask": {
+                    0: torch.export.Dim.STATIC,
+                    1: self.seq_len_dim,
+                    2: torch.export.Dim.STATIC,
+                }
+            },
+        }
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.y),
+                kwargs={
+                    "mask": mask,
+                    "input_pos": self.input_pos,
+                },
+                dynamic_shapes=dynamic_shapes,
+                strict=True,
+            )
+        et_program = to_edge(
+            et_mha_ep,
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                _check_ir_validity=False,
+            ),
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                passes=[InitializedMutableBufferPass(["cache_pos"])],
+            )
+        )
+
+        # First run.
+        runtime = Runtime.get()
+        program = runtime.load_program(et_program.buffer)
+        method = program.load_method("forward")
+        et_res = method.execute((self.x, self.y, mask, self.input_pos))
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+
+        assert_close(et_res[0], tt_res)
+
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+        empty_y = torch.full_like(self.y, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = method.execute((self.x, empty_y, mask, next_input_pos))
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
+
+        assert_close(et_res[0], tt_res)
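
Note on the dict-form dynamic_shapes used in the new setUp() and tests: torch.export accepts a mapping keyed by the names of the module's forward() arguments, which is what lets the torch.cond tests reuse self.dynamic_shapes and splice in an extra "mask" entry. The snippet below is a minimal, self-contained sketch of that pattern; TinyCrossAttnStub is a hypothetical stand-in and not the ExecuTorch MultiHeadAttention, only the torch.export calls mirror the tests above.

import torch

class TinyCrossAttnStub(torch.nn.Module):
    # Hypothetical stand-in with the same argument names (x, y, input_pos)
    # as the attention module exported in the tests.
    def forward(self, x, y, input_pos=None):
        # Tie x, y and input_pos to one sequence length so the shared
        # "seq_len" Dim below is exercised.
        return x + y + input_pos.unsqueeze(-1).to(x.dtype)

embed_dim = 64
seq_len = 10
x = torch.randn(1, seq_len, embed_dim)
y = torch.randn(1, seq_len, embed_dim)
input_pos = torch.arange(seq_len).unsqueeze(0)  # shape [1, seq_len]

seq_len_dim = torch.export.Dim("seq_len", min=1, max=128)
# dynamic_shapes keyed by forward() argument names, mirroring setUp().
dynamic_shapes = {
    "x": {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
    "y": {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
    "input_pos": {0: torch.export.Dim.STATIC, 1: seq_len_dim},
}

ep = torch.export.export(
    TinyCrossAttnStub(),
    (x, y),
    kwargs={"input_pos": input_pos},
    dynamic_shapes=dynamic_shapes,
    strict=True,
)

# The exported module can be re-run with a different sequence length (<= 128).
out = ep.module()(
    torch.randn(1, 5, embed_dim),
    torch.randn(1, 5, embed_dim),
    input_pos=torch.arange(5).unsqueeze(0),
)
print(out.shape)  # torch.Size([1, 5, 64])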