Skip to content

Commit 1ddb636

Browse files
author
David Thrower
committed
Improve the MLflow setup workflow.
1 parent 64c204b commit 1ddb636

File tree

1 file changed

+22
-29
lines changed

1 file changed

+22
-29
lines changed

train_a_generative_llm_docker.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,26 @@
6262
keras_models_folder = f"{ARTIFACTS_FOLDER}/{TIME}/keras_models"
6363
Path(keras_models_folder).mkdir(parents=True, exist_ok=True)
6464

65+
## Dataset Selection
66+
# Assumes:
67+
# 1. Is a huggingface dataset of the structure ...
68+
# 2. Has a key ['train']['text']
69+
# 3. The key duck types as a List[str]
70+
# 4. The samples tokenize consistently with the MAX_SEQUENCE_LENGTH
6571

66-
MLFLOW_PORT = int(os.getenv("MLFLOW_PORT", 7777))
72+
DATASET_TO_RUN = str(os.getenv("DATASET_TO_RUN", "david-thrower/tiny-stories-mini-96-seq-len-50000-samples"))
6773

74+
######################### here ######################
6875

76+
# Samples to use for the neural architecture search stage
77+
PHASE_I_A_SAMPLES_TO_CREATE = int(getenv("PHASE_I_A_SAMPLES_TO_CREATE", "300"))
6978

79+
# Samples to use for the main training stage
80+
PHASE_I_B_SAMPLES_TO_CREATE = int(getenv("PHASE_I_B_SAMPLES_TO_CREATE", "200"))
81+
PHASE_I_B_VAL_SPLIT = float(getenv("PHASE_I_B_VAL_SPLIT", "0.15"))
7082

7183

84+
MLFLOW_PORT = int(os.getenv("MLFLOW_PORT", 7777))
7285

7386
# If you don't want Mlflow, just add `-e MLFLOW_PORT=0` to `docker run`
7487
if MLFLOW_PORT != 0:
@@ -86,38 +99,18 @@
8699
])
87100

88101
answer = subprocess.run(cmd, shell=True)
89-
time.sleep(10)
102+
time.sleep(30)
90103
print(answer.stdout)
91104

92105

106+
# Set up MlFlow experiment
107+
time_hyphenated = TIME.replace('_','-')
108+
ds_root_name = DATASET_TO_RUN.split('/')[-1]
109+
MLFLOW_EXPERIMENT_NAME = f"{time_hyphenated}--llm-training--{ds_root_name}-" +\
110+
f"ia-{PHASE_I_A_SAMPLES_TO_CREATE}-ib-{PHASE_I_B_SAMPLES_TO_CREATE}-a"
93111

94-
## Dataset Selection
95-
# Assumes:
96-
# 1. Is a huggingface dataset of the structure ...
97-
# 2. Has a key ['train']['text']
98-
# 3. The key duck types as a List[str]
99-
# 4. The samples tokenize consistently with the MAX_SEQUENCE_LENGTH
100-
101-
DATASET_TO_RUN = str(os.getenv("DATASET_TO_RUN", "david-thrower/tiny-stories-mini-96-seq-len-50000-samples"))
102-
103-
######################### here ######################
104-
105-
# Samples to use for the neural architecture search stage
106-
PHASE_I_A_SAMPLES_TO_CREATE = int(getenv("PHASE_I_A_SAMPLES_TO_CREATE", "300"))
107-
108-
# Samples to use for the main training stage
109-
PHASE_I_B_SAMPLES_TO_CREATE = int(getenv("PHASE_I_B_SAMPLES_TO_CREATE", "200"))
110-
PHASE_I_B_VAL_SPLIT = float(getenv("PHASE_I_B_VAL_SPLIT", "0.15"))
111-
112-
113-
# Set up MlFlow experiment
114-
time_hyphenated = TIME.replace('_','-')
115-
ds_root_name = DATASET_TO_RUN.split('/')[-1]
116-
EXPERIMENT_NAME = f"{time_hyphenated}--llm-training--{ds_root_name}-" +\
117-
f"ia-{PHASE_I_A_SAMPLES_TO_CREATE}-ib-{PHASE_I_B_SAMPLES_TO_CREATE}-a"
118-
119-
mlflow.set_tracking_uri(uri=f"http://127.0.0.1:{MLFLOW_PORT}")
120-
mlflow.set_experiment(EXPERIMENT_NAME)
112+
mlflow.set_tracking_uri(uri=f"http://127.0.0.1:{MLFLOW_PORT}")
113+
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
121114

122115

123116
# This is a single head model. It only returns the next token. For this reason,

0 commit comments

Comments
 (0)