
Commit 0f2debd

protobird-git authored and copybara-github committed
Fix wrong property access
- property access must not be in the form of a function call
- fix wrong construction of the prefill KV-cache mask when mask_as_input is true

PiperOrigin-RevId: 755434493
1 parent 20b321a commit 0f2debd

File tree

3 files changed: +11 -13 lines changed


ai_edge_torch/generative/examples/gemma3/decoder.py
2 additions, 2 deletions

@@ -201,7 +201,7 @@ def create_sliding_mask(
         torch.zeros_like(sliding_mask_bool, dtype=torch.float),
         torch.full_like(
             sliding_mask_bool,
-            self.config.get_causal_mask_value(),
+            self.config.causal_mask_value,
             dtype=torch.float,
         ),
     )
@@ -219,7 +219,7 @@ def compose_mask(
       mask = torch.logical_and(mask, pixel_mask)
     else:
       mask = torch.logical_or(mask, pixel_mask)
-    mask = torch.where(mask, 0, self.config.get_causal_mask_value())
+    mask = torch.where(mask, 0, self.config.causal_mask_value)
     return mask

   def build_pixel_mask(self, image_indices: torch.Tensor):
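For context, both changed call sites feed causal_mask_value into torch.where, which turns a boolean attend/ignore mask into the additive float mask the attention layers consume. A minimal sketch of that pattern, assuming -inf as the mask value and illustrative shapes (the repo reads the value from the attention config):

```python
import torch

# Illustrative shape; True = attend, False = mask out.
bool_mask = torch.tensor([[True, False, True]])

# Assumed mask value; ai_edge_torch reads it from the attention config.
causal_mask_value = float("-inf")

# Kept positions add 0 to the attention logits; masked positions add -inf,
# which softmax maps to a weight of 0.
float_mask = torch.where(bool_mask, 0.0, causal_mask_value)
print(float_mask)  # tensor([[0., -inf, 0.]])
```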

ai_edge_torch/generative/layers/model_config.py
1 addition, 1 deletion

@@ -251,5 +251,5 @@ def block_config(self, idx: int) -> TransformerBlockConfig:
     return self.block_configs[idx]

   @property
-  def get_causal_mask_value(self) -> float:
+  def causal_mask_value(self) -> float:
     return self.block_config(0).attn_config.causal_mask_value
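This rename is the "wrong property access" from the commit title: get_causal_mask_value was already decorated with @property, so a call site written as get_causal_mask_value() evaluated the property and then tried to call the resulting float. A minimal repro of that failure mode (the class below is illustrative, not the repo's config):

```python
class Config:
  @property
  def get_causal_mask_value(self) -> float:
    return float("-inf")

config = Config()
print(config.get_causal_mask_value)  # -inf: plain attribute access works
config.get_causal_mask_value()       # TypeError: 'float' object is not callable
```

Dropping the get_ prefix also makes the call sites read as the attribute access they actually are.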

ai_edge_torch/generative/utilities/converter.py
8 additions, 10 deletions

@@ -243,15 +243,13 @@ def _export_helper(

   prefill_masks = None
   if flags.FLAGS.mask_as_input:
-    prefill_masks = [
-        _build_mask(
-            flags.FLAGS.prefill_seq_lens,
-            flags.FLAGS.kv_cache_max_len,
-            config.get_causal_mask_value(),
-        )
-    ]
-
-  if prefill_masks:
+    prefill_masks = _build_mask(
+        flags.FLAGS.prefill_seq_lens,
+        flags.FLAGS.kv_cache_max_len,
+        config.causal_mask_value,
+    )
+    if not isinstance(prefill_masks, list):
+      prefill_masks = [prefill_masks]
     assert len(prefill_masks) == len(prefill_seq_lens)

   decode_token = torch.tensor(
@@ -321,7 +319,7 @@ def _export_helper(
     # torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
     #
     sample_kwargs['mask'] = _build_mask(
-        1, flags.FLAGS.kv_cache_max_len, config.get_causal_mask_value()
+        1, flags.FLAGS.kv_cache_max_len, config.causal_mask_value
     )
     if lora is not None:
       sample_kwargs['lora'] = lora
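The prefill fix relies on _build_mask returning a list of masks when flags.FLAGS.prefill_seq_lens is a list; the old code wrapped that whole result in one more list, so len(prefill_masks) was always 1 and the assert against prefill_seq_lens could fail. A sketch of the expected behavior, using a hypothetical stand-in for _build_mask modeled on the torch.triu comment in the diff:

```python
from typing import List, Union

import torch

def build_mask(
    seq_lens: Union[int, List[int]],
    kv_cache_max_len: int,
    mask_value: float,
) -> Union[torch.Tensor, List[torch.Tensor]]:
  """Hypothetical stand-in: one causal mask per sequence length."""
  if isinstance(seq_lens, list):
    return [build_mask(n, kv_cache_max_len, mask_value) for n in seq_lens]
  # Additive causal mask over the full KV cache length: allowed positions
  # are 0, future positions are mask_value.
  mask = torch.full((seq_lens, kv_cache_max_len), mask_value)
  return torch.triu(mask, diagonal=1)

masks = build_mask([8, 64], 1280, float("-inf"))
assert len(masks) == 2  # pairs up with prefill_seq_lens, as the diff's assert expects
```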
