Commit 981119e

forward fix
Differential Revision: D79597179
Pull Request resolved: #13129
1 parent 60158e8 commit 981119e

1 file changed: +3 -3 lines changed
examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 3 additions & 3 deletions
@@ -70,7 +70,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
 
         self.scale = float(self.head_dim) ** 0.5
 
-        if config.enable_r3:
+        if hasattr(config, "enable_r3") and config.enable_r3:
             self.register_buffer(
                 "r3_weight",
                 torch.tensor(
@@ -186,11 +186,11 @@ def forward_sha(
         ]
         for i in range(len(q)):
             q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
-            if self.config.enable_r3:
+            if hasattr(self.config, "enable_r3") and self.config.enable_r3:
                 q[i] = torch.matmul(q[i], self.r3_weight.T)
         for i in range(len(k)):
             k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin)
-            if self.config.enable_r3:
+            if hasattr(self.config, "enable_r3") and self.config.enable_r3:
                 k[i] = torch.matmul(k[i], self.r3_weight.T)
             k[i] = k[i].transpose(1, 2)
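The change guards each enable_r3 access with hasattr(), so configs that predate the enable_r3 field fall through to the non-R3 path instead of raising AttributeError. A minimal sketch of the difference, using a hypothetical stripped-down LegacyModelArgs (illustrative only, not the actual ExecuTorch ModelArgs class):

    from dataclasses import dataclass


    @dataclass
    class LegacyModelArgs:
        # Hypothetical older config: note there is no enable_r3 field.
        dim: int = 2048
        n_heads: int = 16


    config = LegacyModelArgs()

    # Pre-fix check: accessing the missing attribute raises AttributeError.
    try:
        if config.enable_r3:
            print("R3 rotation applied")
    except AttributeError as err:
        print(f"old check fails: {err}")

    # Post-fix check: hasattr() short-circuits, so the R3 path is simply skipped.
    if hasattr(config, "enable_r3") and config.enable_r3:
        print("R3 rotation applied")
    else:
        print("R3 rotation skipped")

An equivalent, slightly more compact guard would be getattr(config, "enable_r3", False); the commit simply spells out the hasattr() form at each call site.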
