@@ -689,14 +689,48 @@ def impl_update_training_progress(model_name, dataset_name, progress_text, bucke
689689 with open (local_file , "w" ) as f :
690690 f .write (progress_text )
691691 client = make_client ("s3" , region )
692- remote_path = "axon-training-progress/" + os .path .basename (model_name ) + "/" + \
693- os .path .basename (dataset_name ) + "/progress.txt"
692+ remote_path = create_progress_prefix (model_name , dataset_name ) + "/progress.txt"
694693 client .upload_file (path , bucket_name , remote_path )
695694 print ("Updated progress in: {}\n " .format (remote_path ))
696695 finally :
697696 os .remove (path )
698697
699698
699+ def impl_create_heartbeat (model_name , dataset_name , bucket_name , region ):
700+ """
701+ Creates a heartbeat that Axon uses to check if the training script is running properly.
702+
703+ :param model_name: The filename of the model.
704+ :param dataset_name: The filename of the dataset.
705+ :param bucket_name: The S3 bucket name.
706+ :param region: The region, or `None` to pull the region from the environment.
707+ """
708+ client = make_client ("s3" , region )
709+ remote_path = create_progress_prefix (model_name , dataset_name ) + "/heartbeat.txt"
710+ client .put_object (Body = "1" , Bucket = bucket_name , Key = remote_path )
711+ print ("Created heartbeat file in: {}\n " .format (remote_path ))
712+
713+
714+ def impl_remove_heartbeat (model_name , dataset_name , bucket_name , region ):
715+ """
716+ Removes a heartbeat that Axon uses to check if the training script is running properly.
717+
718+ :param model_name: The filename of the model.
719+ :param dataset_name: The filename of the dataset.
720+ :param bucket_name: The S3 bucket name.
721+ :param region: The region, or `None` to pull the region from the environment.
722+ """
723+ client = make_client ("s3" , region )
724+ remote_path = create_progress_prefix (model_name , dataset_name ) + "/heartbeat.txt"
725+ client .put_object (Body = "0" , Bucket = bucket_name , Key = remote_path )
726+ print ("Removed heartbeat file in: {}\n " .format (remote_path ))
727+
728+
729+ def create_progress_prefix (model_name , dataset_name ):
730+ return "axon-training-progress/" + os .path .basename (model_name ) + "/" + \
731+ os .path .basename (dataset_name )
732+
733+
700734@click .group ()
701735def cli ():
702736 return
@@ -902,3 +936,35 @@ def update_training_progress(model_name, dataset_name, progress_text, region):
902936 """
903937 impl_update_training_progress (model_name , dataset_name , progress_text , ensure_s3_bucket (region ),
904938 region )
939+
940+
941+ @cli .command (name = "create-heartbeat" )
942+ @click .argument ("model-name" )
943+ @click .argument ("dataset-name" )
944+ @click .option ("--region" , help = "The region to connect to." ,
945+ type = click .Choice (region_choices ))
946+ def create_heartbeat (model_name , dataset_name , region ):
947+ """
948+ Creates a heartbeat that Axon uses to check if the training script is running properly.
949+
950+ MODEL_NAME The filename of the model currently being trained.
951+
952+ DATASET_NAME The name of the dataset currently being trained on.
953+ """
954+ impl_create_heartbeat (model_name , dataset_name , ensure_s3_bucket (region ), region )
955+
956+
957+ @cli .command (name = "remove-heartbeat" )
958+ @click .argument ("model-name" )
959+ @click .argument ("dataset-name" )
960+ @click .option ("--region" , help = "The region to connect to." ,
961+ type = click .Choice (region_choices ))
962+ def remove_heartbeat (model_name , dataset_name , region ):
963+ """
964+ Removes a heartbeat that Axon uses to check if the training script is running properly.
965+
966+ MODEL_NAME The filename of the model currently being trained.
967+
968+ DATASET_NAME The name of the dataset currently being trained on.
969+ """
970+ impl_remove_heartbeat (model_name , dataset_name , ensure_s3_bucket (region ), region )
0 commit comments