
Commit 9969571

waitzkin and carmocca authored
Add Loss Averaging and no_backward_sync() in Gradient Accumulation Logic (#357)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 0819c7f commit 9969571

File tree

4 files changed (+11, -10)

finetune/adapter.py

Lines changed: 2 additions & 2 deletions
@@ -131,9 +131,9 @@ def train(
         t0 = time.time()
 
         input_ids, targets = get_batch(fabric, train_data)
-        logits = model(input_ids)
-        loss = loss_fn(logits, targets)
         with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
+            logits = model(input_ids)
+            loss = loss_fn(logits, targets)
             fabric.backward(loss / gradient_accumulation_iters)
 
         if (iter_num + 1) % gradient_accumulation_iters == 0:

finetune/adapter_v2.py

Lines changed: 2 additions & 2 deletions
@@ -137,9 +137,9 @@ def train(
         t0 = time.time()
 
         input_ids, targets = get_batch(fabric, train_data)
-        logits = model(input_ids)
-        loss = loss_fn(logits, targets)
         with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
+            logits = model(input_ids)
+            loss = loss_fn(logits, targets)
             fabric.backward(loss / gradient_accumulation_iters)
 
         if (iter_num + 1) % gradient_accumulation_iters == 0:

finetune/full.py

Lines changed: 3 additions & 3 deletions
@@ -114,12 +114,12 @@ def train(
                 param_group['lr'] = lr
 
         t0 = time.time()
-
+
+        input_ids, targets = get_batch(fabric, train_data)
         with fabric.no_backward_sync(model, enabled=is_accumulating):
-            input_ids, targets = get_batch(fabric, train_data)
             logits = model(input_ids)
             loss = loss_fn(logits, targets)
-            fabric.backward(loss)
+            fabric.backward(loss / gradient_accumulation_iters)
 
         if not is_accumulating:
             optimizer.step()

finetune/lora.py

Lines changed: 4 additions & 3 deletions
@@ -108,9 +108,10 @@ def train(
         t0 = time.time()
 
         input_ids, targets = get_batch(fabric, train_data)
-        logits = model(input_ids)
-        loss = loss_fn(logits, targets)
-        fabric.backward(loss)
+        with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
+            logits = model(input_ids)
+            loss = loss_fn(logits, targets)
+            fabric.backward(loss / gradient_accumulation_iters)
 
         if (iter_num + 1) % gradient_accumulation_iters == 0:
             optimizer.step()
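
For reference, below is a minimal, self-contained sketch of the training-loop pattern this commit standardizes across the four finetuning scripts: the forward and backward passes run inside fabric.no_backward_sync(), so the distributed gradient all-reduce is skipped on accumulation steps, and the loss is divided by gradient_accumulation_iters so the accumulated gradient corresponds to the loss averaged over the accumulation window. The tiny model, synthetic data, loss_fn, and hyperparameters below are illustrative stand-ins, not the repository's actual fine-tuning setup.

import torch
import torch.nn as nn
from lightning.fabric import Fabric

gradient_accumulation_iters = 4
max_iters = 16

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
model, optimizer = fabric.setup(model, optimizer)
loss_fn = nn.MSELoss()  # stand-in for the scripts' cross-entropy loss


def get_batch():
    # Stand-in for get_batch(fabric, train_data) in the finetuning scripts.
    x = torch.randn(4, 8)
    return fabric.to_device(x), fabric.to_device(x)


for iter_num in range(max_iters):
    input_ids, targets = get_batch()
    # Skip the cross-rank gradient sync on accumulation steps; it only needs
    # to happen on the iteration that will call optimizer.step().
    with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
        logits = model(input_ids)
        loss = loss_fn(logits, targets)
        # Divide by the accumulation count so the accumulated gradient matches
        # the average loss over the window rather than its sum.
        fabric.backward(loss / gradient_accumulation_iters)

    if (iter_num + 1) % gradient_accumulation_iters == 0:
        optimizer.step()
        optimizer.zero_grad()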
