Skip to content

Commit 5521907

Browse files
authored
update padding strategy for persistent cache (#2464)
1 parent 616480e commit 5521907

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

swift/torchacc_utils.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,26 @@ def get_bucket_sizes(max_length: int) -> List[int]:
2727
the bucket sizes. If not set, we use a normal distribution bucketing with
2828
8 buckets.
2929
"""
30+
padding_p_base = 2
3031
if os.getenv('TORCHACC_DATA_BUCKETS') is not None:
3132
bucket_sizes = [int(x) for x in os.getenv('TORCHACC_DATA_BUCKETS').split(',')]
3233
bucket_sizes.append(max_length)
33-
else: # default normal distribution bucketing.
34-
mean = max_length // 2
35-
var = max_length // 8
36-
bucket_sizes = [mean + i * var for i in range(-3, 4)]
34+
else:
35+
if os.getenv('TORCHACC_CACHE_PATH') is not None: # padding strategy when persistent cache is enabled
36+
padding_p_base = 1.4
37+
padding_p_base = os.getenv('TORCHACC_PADDING_P_BASE', padding_p_base)
38+
try:
39+
padding_p_base = float(padding_p_base)
40+
except ValueError as e:
41+
logger.error(f'Expected TORCHACC_PADDING_P_BASE to be a float number, but encountered {padding_p_base}')
42+
raise e
43+
bucket_sizes = [16, 32, 48, 64, 96, 128]
44+
base_size = 256
45+
while base_size < max_length:
46+
bucket_sizes.append((int(base_size) + 127) // 128 * 128)
47+
base_size *= padding_p_base
3748
bucket_sizes.append(max_length)
49+
3850
return bucket_sizes
3951

4052

swift/trainers/trainers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,9 @@ def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=No
213213
acc = torch.tensor(acc_list, device=preds.device).float().mean()
214214
else:
215215
if use_torchacc():
216-
ta_trim_graph()
216+
# Only enabled during evaluation/test
217+
if not model.training:
218+
ta_trim_graph()
217219
preds = preds.to('cpu')
218220
masks = masks.to('cpu')
219221
labels = labels.to('cpu')

0 commit comments

Comments
 (0)