Skip to content

Commit d151eaa

Browse files
committed
feat: extend ResNet-finetune with additional metrics and rename default artifact directory
- Added CUDA, CPU usage, CPU memory, and learning rate metrics for enhanced training analysis. - Renamed default artifact directory to `resnet_finetune` for consistency.
1 parent 29a6c62 commit d151eaa

File tree

1 file changed

+11
-2
lines changed
  • examples/resnet_finetune/src

1 file changed

+11
-2
lines changed

examples/resnet_finetune/src/main.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ use burn::prelude::{Int, Tensor};
2323
use burn::record::CompactRecorder;
2424
use burn::tensor::backend::{AutodiffBackend, Backend};
2525
use burn::train::metric::store::{Aggregate, Direction, Split};
26-
use burn::train::metric::{HammingScore, LossMetric};
26+
use burn::train::metric::{
27+
CpuMemory, CpuUse, CudaMetric, HammingScore, LearningRateMetric, LossMetric,
28+
};
2729
use burn::train::renderer::{
2830
EvaluationName, EvaluationProgress, MetricState, MetricsRenderer, MetricsRendererEvaluation,
2931
MetricsRendererTraining, TrainingProgress,
@@ -57,7 +59,7 @@ pub struct Args {
5759
pub train_percentage: u8,
5860

5961
/// Directory to save the artifacts.
60-
#[arg(long, default_value = "/tmp/resnet-finetune")]
62+
#[arg(long, default_value = "/tmp/resnet_finetune")]
6163
artifact_dir: String,
6264

6365
/// Batch size for processing
@@ -288,6 +290,13 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
288290
.metric_valid_numeric(HammingScore::new())
289291
.metric_train_numeric(LossMetric::new())
290292
.metric_valid_numeric(LossMetric::new())
293+
.metric_train(CudaMetric::new())
294+
.metric_valid(CudaMetric::new())
295+
.metric_train_numeric(CpuUse::new())
296+
.metric_valid_numeric(CpuUse::new())
297+
.metric_train_numeric(CpuMemory::new())
298+
.metric_valid_numeric(CpuMemory::new())
299+
.metric_train_numeric(LearningRateMetric::new())
291300
.with_file_checkpointer(CompactRecorder::new())
292301
.early_stopping(MetricEarlyStoppingStrategy::new(
293302
&LossMetric::<B>::new(),

0 commit comments

Comments
 (0)