Skip to content

Commit a06dadc

Browse files
add checkAlive in NasBertTrainer (#6546)
1 parent eeba2ee commit a06dadc

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ public override NasBertTransformer Fit(IDataView input)
249249
for (int i = 0; i < Option.MaxEpoch; i++)
250250
{
251251
ch.Trace($"Starting epoch {i}");
252-
trainer.Train(input);
252+
Host.CheckAlive();
253+
trainer.Train(Host, input);
253254
ch.Trace($"Finished epoch {i}");
254255
if (Option.ValidationSet != null)
255256
trainer.Validate(pch, ch, i);
@@ -423,7 +424,7 @@ private bool ValidateStep(DataViewRowCursor cursor,
423424
return cursorValid;
424425
}
425426

426-
public void Train(IDataView input)
427+
public void Train(IHost host, IDataView input)
427428
{
428429
// Get the cursor and the correct columns based on the inputs
429430
DataViewRowCursor cursor = default;
@@ -443,14 +444,15 @@ public void Train(IDataView input)
443444
var cursorValid = true;
444445
while (cursorValid)
445446
{
446-
cursorValid = TrainStep(cursor, sentence1Getter, sentence2Getter, labelGetter, ref inputTensors, ref targets);
447+
cursorValid = TrainStep(host, cursor, sentence1Getter, sentence2Getter, labelGetter, ref inputTensors, ref targets);
447448
}
448449
}
449450

450-
private bool TrainStep(DataViewRowCursor cursor,
451-
ValueGetter<ReadOnlyMemory<char>> sentence1Getter,
452-
ValueGetter<ReadOnlyMemory<char>> sentence2Getter,
453-
ValueGetter<TLabelCol> labelGetter,
451+
private bool TrainStep(IHost host,
452+
DataViewRowCursor cursor,
453+
ValueGetter<ReadOnlyMemory<char>> sentence1Getter,
454+
ValueGetter<ReadOnlyMemory<char>> sentence2Getter,
455+
ValueGetter<TLabelCol> labelGetter,
454456
ref List<Tensor> inputTensors,
455457
ref List<TTargetsCol> targets)
456458
{
@@ -461,6 +463,7 @@ private bool TrainStep(DataViewRowCursor cursor,
461463
var cursorValid = true;
462464
for (int i = 0; i < Parent.Option.BatchSize && cursorValid; i++)
463465
{
466+
host.CheckAlive();
464467
cursorValid = cursor.MoveNext();
465468
if (cursorValid)
466469
{
@@ -479,7 +482,7 @@ private bool TrainStep(DataViewRowCursor cursor,
479482
}
480483

481484
Updates++;
482-
485+
host.CheckAlive();
483486
torch.random.manual_seed(1 + Updates);
484487
torch.cuda.manual_seed(1 + Updates);
485488
Model.train();
@@ -497,8 +500,10 @@ private bool TrainStep(DataViewRowCursor cursor,
497500
loss = torch.nn.MSELoss(reduction: Parent.Option.Reduction).forward(logits, targetsTensor);
498501
logits = logits.squeeze();
499502
}
500-
503+
host.CheckAlive();
501504
loss.backward();
505+
506+
host.CheckAlive();
502507
OptimizeStep();
503508

504509
return cursorValid;

0 commit comments

Comments
 (0)