Skip to content

Commit 1340e6f

Browse files
added batch_size argument to train cli
1 parent 9b3cf80 commit 1340e6f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

orthophoto-segmentation-benchmark-toolkit/main_training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
if __name__ == '__main__':
1212
parser = argparse.ArgumentParser()
1313
parser.add_argument("--epochs", help="model weights of given eperiment will be used for training", default=30)
14+
parser.add_argument("--bs", help="model weights of given eperiment will be used for training", default=4)
1415
parser.add_argument("--experiment", help="model weights of given eperiment will be used for training", default="")
1516
args = parser.parse_args()
1617

@@ -21,7 +22,7 @@
2122
dataset = DroneDeployDataset(dataset, size).download().generate_chips()
2223
model_backend = UnetBackend('efficientnetb3')
2324

24-
experiment = Experiment("test", dataset, model_backend, batch_size=1,
25+
experiment = Experiment("test", dataset, model_backend, batch_size=args.bs,
2526
experiment_directory=args.experiment, load_best=False)
2627
experiment.analyze()
2728
experiment.train(epochs=args.epochs)

0 commit comments

Comments
 (0)