* Started making changes to use native Pytorch AMP
* Updated compute_loss functions to use torch.cuda.amp.autocast
* Updating docstrings
* Add use_amp to trainer_checkpoint
* Removed mentions of apex and started to add the necessary warnings
* Removing unused instances of use_amp variable
* Added fast training test for FARMReader. Needed to add max_query_length as a parameter in FARMReader.__init__ and FARMReader.train
* Make max_query_length optional in FARMReader.train
* Update lg
Co-authored-by: Agnieszka Marzec <[email protected]>
Co-authored-by: agnieszka-m <[email protected]>
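The commits above replace apex with PyTorch's native AMP (`torch.cuda.amp.autocast`). A minimal sketch of what a native-AMP training step looks like is below; the model, optimizer, and data here are hypothetical illustrations, not code from this PR, and `enabled=` is gated on CUDA availability so the snippet also runs on CPU:

```python
import torch

# Hypothetical tiny model and optimizer for illustration only
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# GradScaler guards FP16 gradients against underflow; it is a no-op when disabled
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

inputs = torch.randn(4, 8)
labels = torch.randint(0, 2, (4,))

optimizer.zero_grad()
# autocast runs the forward pass in mixed precision when CUDA is available
with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
    logits = model(inputs)
    loss = torch.nn.functional.cross_entropy(logits, labels)

# Scale the loss before backward, then step and update through the scaler
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

Compared with apex's `"O0"`–`"O3"` optimization levels, native AMP reduces the choice to a single on/off flag, which is what the `use_amp` docstring change below reflects.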
@@ -389,8 +381,7 @@ that gets split off from training data for eval.
 A list containing torch device objects and/or strings is supported (For example
 [torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
 parameter is not used and a single cpu device is used for inference.
--`student_batch_size`: Number of samples the student model receives in one batch for training
--`student_batch_size`: Number of samples the teacher model receives in one batch for distillation
+-`batch_size`: Number of samples the student model and teacher model receive in one batch for training
 -`n_epochs`: Number of iterations on the whole training data set
 -`learning_rate`: Learning rate of the optimizer
 -`max_seq_len`: Maximum text length (in tokens). Everything longer gets cut down.
@@ -402,21 +393,16 @@ Options for different schedules are available in FARM.
 -`num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing.
 Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
 Set to None to use all CPU cores minus one.
--`use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
-Available options:
-None (Don't use AMP)
-"O0" (Normal FP32 training)
-"O1" (Mixed Precision => Recommended)
-"O2" (Almost FP16)
-"O3" (Pure FP16).
-See details on: https://nvidia.github.io/apex/amp.html
+-`use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
+training speed and reduce GPU memory usage.
+For more information, see [Haystack Optimization](https://haystack.deepset.ai/guides/optimization)
+and [Automatic Mixed Precision Package - torch.amp](https://pytorch.org/docs/stable/amp.html).
 -`checkpoint_root_dir`: the Path of directory where all train checkpoints are saved. For each individual
 checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
 -`checkpoint_every`: save a train checkpoint after this many steps of training.
 -`checkpoints_to_keep`: maximum number of train checkpoints to save.
 -`caching`: whether or not to use caching for preprocessed dataset and teacher logits
 -`cache_path`: Path to cache the preprocessed dataset and teacher logits
--`distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
 -`distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
 -`temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
 -`processor`: The processor to use for preprocessing. If None, the default SquadProcessor is used.
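The `distillation_loss` parameter accepts a callable with named parameters `student_logits` and `teacher_logits`, and `temperature` softens the teacher distribution as described above. A hedged sketch of what such a callable could look like (an illustration of the documented contract, not Haystack's actual implementation):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=1.0):
    """Illustrative KL-divergence distillation loss with temperature scaling."""
    # Higher temperature softens the teacher distribution (less certainty)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures
    return F.kl_div(student_log_probs, teacher_probs,
                    reduction="batchmean") * temperature ** 2

# Identical student and teacher logits give (near-)zero divergence
student = torch.tensor([[2.0, 0.5, 0.1]])
teacher = torch.tensor([[2.0, 0.5, 0.1]])
loss_same = distillation_loss(student, teacher, temperature=2.0)
```

At `temperature=1.0` the softmax distributions are unchanged, which matches the docstring's note that a temperature of 1.0 does not change the certainty of the model.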
@@ -663,7 +649,7 @@ Example:
 **Arguments**:

 -`question`: Question string
--`documents`: List of documents as string type
+-`texts`: A list of Document texts as strings