Skip to content

Commit 69b67c4

Browse files
fix #6416 (#6612)
* remove handler
* update
* checkout test file
1 parent 7f94445 commit 69b67c4

File tree

1 file changed

+70
-71
lines changed

1 file changed

+70
-71
lines changed

src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs

Lines changed: 70 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -248,94 +248,93 @@ public async Task<TrialResult> RunAsync(CancellationToken ct = default)
248248
var parameter = tuner.Propose(trialSettings);
249249
trialSettings.Parameter = parameter;
250250

251-
using (var trialCancellationTokenSource = new CancellationTokenSource())
251+
var trialCancellationTokenSource = new CancellationTokenSource();
252+
monitor?.ReportRunningTrial(trialSettings);
253+
var stopTrialManager = new CancellationTokenStopTrainingManager(trialCancellationTokenSource.Token, null);
254+
aggregateTrainingStopManager.AddTrainingStopManager(stopTrialManager);
255+
void handler(object o, EventArgs e)
252256
{
253-
monitor?.ReportRunningTrial(trialSettings);
254-
255-
void handler(object o, EventArgs e)
256-
{
257-
trialCancellationTokenSource.Cancel();
258-
}
259-
try
257+
trialCancellationTokenSource.Cancel();
258+
}
259+
try
260+
{
261+
using (var performanceMonitor = serviceProvider.GetService<IPerformanceMonitor>())
262+
using (var runner = serviceProvider.GetRequiredService<ITrialRunner>())
260263
{
261-
using (var performanceMonitor = serviceProvider.GetService<IPerformanceMonitor>())
262-
using (var runner = serviceProvider.GetRequiredService<ITrialRunner>())
264+
aggregateTrainingStopManager.OnStopTraining += handler;
265+
performanceMonitor.PerformanceMetricsUpdated += (o, metrics) =>
263266
{
264-
aggregateTrainingStopManager.OnStopTraining += handler;
265-
performanceMonitor.PerformanceMetricsUpdated += (o, metrics) =>
266-
{
267-
performanceMonitor.OnPerformanceMetricsUpdatedHandler(trialSettings, metrics, trialCancellationTokenSource);
268-
};
269-
270-
performanceMonitor.Start();
271-
logger.Trace($"trial setting - {JsonSerializer.Serialize(trialSettings)}");
272-
var trialResult = await runner.RunAsync(trialSettings, trialCancellationTokenSource.Token);
273-
274-
var peakCpu = performanceMonitor?.GetPeakCpuUsage();
275-
var peakMemoryInMB = performanceMonitor?.GetPeakMemoryUsageInMegaByte();
276-
trialResult.PeakCpu = peakCpu;
277-
trialResult.PeakMemoryInMegaByte = peakMemoryInMB;
278-
trialResult.TrialSettings.EndedAtUtc = DateTime.UtcNow;
279-
280-
performanceMonitor.Pause();
281-
monitor?.ReportCompletedTrial(trialResult);
282-
tuner.Update(trialResult);
283-
trialResultManager?.AddOrUpdateTrialResult(trialResult);
284-
aggregateTrainingStopManager.Update(trialResult);
285-
286-
var loss = trialResult.Loss;
287-
if (loss < _bestLoss)
288-
{
289-
_bestTrialResult = trialResult;
290-
_bestLoss = loss;
291-
monitor?.ReportBestTrial(trialResult);
292-
}
293-
}
294-
}
295-
catch (Exception ex) when (aggregateTrainingStopManager.IsStopTrainingRequested() == false)
296-
{
297-
var exceptionMessage = $@"
298-
Exception thrown during Trial {trialSettings.TrialId} with configuration {JsonSerializer.Serialize(trialSettings)}
267+
performanceMonitor.OnPerformanceMetricsUpdatedHandler(trialSettings, metrics, trialCancellationTokenSource);
268+
};
299269

300-
Exception Details: ex.Message
270+
performanceMonitor.Start();
271+
logger.Trace($"trial setting - {JsonSerializer.Serialize(trialSettings)}");
272+
var trialResult = await runner.RunAsync(trialSettings, trialCancellationTokenSource.Token);
301273

302-
Abandoning Trial {trialSettings.TrialId} and continue training.
303-
";
304-
logger.Trace(exceptionMessage);
305-
trialSettings.EndedAtUtc = DateTime.UtcNow;
306-
monitor?.ReportFailTrial(trialSettings, ex);
307-
var trialResult = new TrialResult
308-
{
309-
TrialSettings = trialSettings,
310-
Loss = double.MaxValue,
311-
};
274+
var peakCpu = performanceMonitor?.GetPeakCpuUsage();
275+
var peakMemoryInMB = performanceMonitor?.GetPeakMemoryUsageInMegaByte();
276+
trialResult.PeakCpu = peakCpu;
277+
trialResult.PeakMemoryInMegaByte = peakMemoryInMB;
278+
trialResult.TrialSettings.EndedAtUtc = DateTime.UtcNow;
312279

280+
performanceMonitor.Pause();
281+
monitor?.ReportCompletedTrial(trialResult);
313282
tuner.Update(trialResult);
314283
trialResultManager?.AddOrUpdateTrialResult(trialResult);
315284
aggregateTrainingStopManager.Update(trialResult);
316285

317-
if (ex is not OperationCanceledException && _bestTrialResult == null)
286+
var loss = trialResult.Loss;
287+
if (loss < _bestLoss)
318288
{
319-
logger.Trace($"trial fatal error - {JsonSerializer.Serialize(trialSettings)}, stop training");
320-
321-
// TODO
322-
// it's questionable on whether to abort the entire training process
323-
// for a single fail trial. We should make it an option and only exit
324-
// when error is fatal (like schema mismatch).
325-
throw;
289+
_bestTrialResult = trialResult;
290+
_bestLoss = loss;
291+
monitor?.ReportBestTrial(trialResult);
326292
}
327-
continue;
328293
}
329-
catch (Exception) when (aggregateTrainingStopManager.IsStopTrainingRequested())
294+
}
295+
catch (Exception ex) when (aggregateTrainingStopManager.IsStopTrainingRequested() == false)
296+
{
297+
var exceptionMessage = $@"
298+
Exception thrown during Trial {trialSettings.TrialId} with configuration {JsonSerializer.Serialize(trialSettings)}
299+
300+
Exception Details: {ex.Message}
301+
302+
Abandoning Trial {trialSettings.TrialId} and continue training.
303+
";
304+
logger.Trace(exceptionMessage);
305+
trialSettings.EndedAtUtc = DateTime.UtcNow;
306+
monitor?.ReportFailTrial(trialSettings, ex);
307+
var trialResult = new TrialResult
330308
{
331-
logger.Trace($"trial cancelled - {JsonSerializer.Serialize(trialSettings)}, stop training");
309+
TrialSettings = trialSettings,
310+
Loss = double.MaxValue,
311+
};
332312

333-
break;
334-
}
335-
finally
313+
tuner.Update(trialResult);
314+
trialResultManager?.AddOrUpdateTrialResult(trialResult);
315+
aggregateTrainingStopManager.Update(trialResult);
316+
317+
if (ex is not OperationCanceledException && _bestTrialResult == null)
336318
{
337-
aggregateTrainingStopManager.OnStopTraining -= handler;
319+
logger.Trace($"trial fatal error - {JsonSerializer.Serialize(trialSettings)}, stop training");
320+
321+
// TODO
322+
// it's questionable on whether to abort the entire training process
323+
// for a single fail trial. We should make it an option and only exit
324+
// when error is fatal (like schema mismatch).
325+
throw;
338326
}
327+
continue;
328+
}
329+
catch (Exception) when (aggregateTrainingStopManager.IsStopTrainingRequested())
330+
{
331+
logger.Trace($"trial cancelled - {JsonSerializer.Serialize(trialSettings)}, stop training");
332+
333+
break;
334+
}
335+
finally
336+
{
337+
aggregateTrainingStopManager.OnStopTraining -= handler;
339338
}
340339
}
341340

0 commit comments

Comments
 (0)