4 files changed, +11 -10 lines changed
@@ -131,9 +131,9 @@ def train(
         t0 = time.time()

         input_ids, targets = get_batch(fabric, train_data)
-        logits = model(input_ids)
-        loss = loss_fn(logits, targets)
         with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
+            logits = model(input_ids)
+            loss = loss_fn(logits, targets)
             fabric.backward(loss / gradient_accumulation_iters)

         if (iter_num + 1) % gradient_accumulation_iters == 0:
@@ -137,9 +137,9 @@ def train(
         t0 = time.time()

         input_ids, targets = get_batch(fabric, train_data)
-        logits = model(input_ids)
-        loss = loss_fn(logits, targets)
         with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
+            logits = model(input_ids)
+            loss = loss_fn(logits, targets)
             fabric.backward(loss / gradient_accumulation_iters)

         if (iter_num + 1) % gradient_accumulation_iters == 0:
@@ -114,12 +114,12 @@ def train(
             param_group['lr'] = lr

         t0 = time.time()
-
+
+        input_ids, targets = get_batch(fabric, train_data)
         with fabric.no_backward_sync(model, enabled=is_accumulating):
-            input_ids, targets = get_batch(fabric, train_data)
             logits = model(input_ids)
             loss = loss_fn(logits, targets)
-            fabric.backward(loss)
+            fabric.backward(loss / gradient_accumulation_iters)

         if not is_accumulating:
             optimizer.step()
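
Besides hoisting get_batch out of the context manager, this hunk also scales the loss by 1 / gradient_accumulation_iters before fabric.backward, so the gradients accumulated over the micro-batches average to the full-batch gradient instead of summing to gradient_accumulation_iters times it. A minimal standalone check of that equivalence in plain PyTorch (the toy tensors and names below are illustrative only, not taken from the changed scripts):

import torch

# Accumulating gradients of (loss / N) over N micro-batches should match
# one backward pass over the mean loss of the full batch.
torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
data = torch.randn(8, 4)
N = 4  # stands in for gradient_accumulation_iters

# One backward pass over the full batch (mean loss).
full_loss = (data @ w).pow(2).mean()
full_loss.backward()
full_grad = w.grad.clone()

# Accumulated micro-batch gradients with the 1/N scaling.
w.grad = None
for micro in data.chunk(N):
    micro_loss = (micro @ w).pow(2).mean()
    (micro_loss / N).backward()

print(torch.allclose(w.grad, full_grad))  # True, up to float rounding
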
@@ -108,9 +108,10 @@ def train(
         t0 = time.time()

         input_ids, targets = get_batch(fabric, train_data)
-        logits = model(input_ids)
-        loss = loss_fn(logits, targets)
-        fabric.backward(loss)
+        with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
+            logits = model(input_ids)
+            loss = loss_fn(logits, targets)
+            fabric.backward(loss / gradient_accumulation_iters)

         if (iter_num + 1) % gradient_accumulation_iters == 0:
             optimizer.step()
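
All four hunks converge on the same pattern: run the forward pass and fabric.backward inside fabric.no_backward_sync, enabled on every iteration except the one that steps the optimizer, and divide the loss by gradient_accumulation_iters. The sketch below is a minimal single-process illustration of that training step with Lightning Fabric; the toy model, optimizer, loop length, and random batch are assumptions standing in for the scripts' actual model and get_batch(fabric, train_data), not the real code:

import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

# Toy stand-ins for the real model and optimizer in the changed scripts.
model = nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)

gradient_accumulation_iters = 4

for iter_num in range(8):
    # Stand-in for get_batch(fabric, train_data): a random micro-batch.
    input_ids = torch.randn(2, 16, device=fabric.device)
    targets = torch.randint(0, 4, (2,), device=fabric.device)

    is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0

    # Skip cross-process gradient synchronization on accumulation steps;
    # both forward and backward live inside the context, as in the diffs.
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        logits = model(input_ids)
        loss = F.cross_entropy(logits, targets)
        # Scale so the accumulated gradient matches a single large-batch step.
        fabric.backward(loss / gradient_accumulation_iters)

    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()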