@@ -249,7 +249,8 @@ public override NasBertTransformer Fit(IDataView input)
249
249
for ( int i = 0 ; i < Option . MaxEpoch ; i ++ )
250
250
{
251
251
ch . Trace ( $ "Starting epoch { i } ") ;
252
- trainer . Train ( input ) ;
252
+ Host . CheckAlive ( ) ;
253
+ trainer . Train ( Host , input ) ;
253
254
ch . Trace ( $ "Finished epoch { i } ") ;
254
255
if ( Option . ValidationSet != null )
255
256
trainer . Validate ( pch , ch , i ) ;
@@ -423,7 +424,7 @@ private bool ValidateStep(DataViewRowCursor cursor,
423
424
return cursorValid ;
424
425
}
425
426
426
- public void Train ( IDataView input )
427
+ public void Train ( IHost host , IDataView input )
427
428
{
428
429
// Get the cursor and the correct columns based on the inputs
429
430
DataViewRowCursor cursor = default ;
@@ -443,14 +444,15 @@ public void Train(IDataView input)
443
444
var cursorValid = true ;
444
445
while ( cursorValid )
445
446
{
446
- cursorValid = TrainStep ( cursor , sentence1Getter , sentence2Getter , labelGetter , ref inputTensors , ref targets ) ;
447
+ cursorValid = TrainStep ( host , cursor , sentence1Getter , sentence2Getter , labelGetter , ref inputTensors , ref targets ) ;
447
448
}
448
449
}
449
450
450
- private bool TrainStep ( DataViewRowCursor cursor ,
451
- ValueGetter < ReadOnlyMemory < char > > sentence1Getter ,
452
- ValueGetter < ReadOnlyMemory < char > > sentence2Getter ,
453
- ValueGetter < TLabelCol > labelGetter ,
451
+ private bool TrainStep ( IHost host ,
452
+ DataViewRowCursor cursor ,
453
+ ValueGetter < ReadOnlyMemory < char > > sentence1Getter ,
454
+ ValueGetter < ReadOnlyMemory < char > > sentence2Getter ,
455
+ ValueGetter < TLabelCol > labelGetter ,
454
456
ref List < Tensor > inputTensors ,
455
457
ref List < TTargetsCol > targets )
456
458
{
@@ -461,6 +463,7 @@ private bool TrainStep(DataViewRowCursor cursor,
461
463
var cursorValid = true ;
462
464
for ( int i = 0 ; i < Parent . Option . BatchSize && cursorValid ; i ++ )
463
465
{
466
+ host . CheckAlive ( ) ;
464
467
cursorValid = cursor . MoveNext ( ) ;
465
468
if ( cursorValid )
466
469
{
@@ -479,7 +482,7 @@ private bool TrainStep(DataViewRowCursor cursor,
479
482
}
480
483
481
484
Updates ++ ;
482
-
485
+ host . CheckAlive ( ) ;
483
486
torch . random . manual_seed ( 1 + Updates ) ;
484
487
torch . cuda . manual_seed ( 1 + Updates ) ;
485
488
Model . train ( ) ;
@@ -497,8 +500,10 @@ private bool TrainStep(DataViewRowCursor cursor,
497
500
loss = torch . nn . MSELoss ( reduction : Parent . Option . Reduction ) . forward ( logits , targetsTensor ) ;
498
501
logits = logits . squeeze ( ) ;
499
502
}
500
-
503
+ host . CheckAlive ( ) ;
501
504
loss . backward ( ) ;
505
+
506
+ host . CheckAlive ( ) ;
502
507
OptimizeStep ( ) ;
503
508
504
509
return cursorValid ;
0 commit comments