-
Notifications
You must be signed in to change notification settings - Fork 285
Expand file tree
/
Copy pathhp_sweep_gpt.py
More file actions
843 lines (662 loc) · 28.3 KB
/
hp_sweep_gpt.py
File metadata and controls
843 lines (662 loc) · 28.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
# ---
# cmd: ["modal", "run", "06_gpu_and_ml/hyperparameter-sweep/hp_sweep_gpt.py", "--n-steps", "200", "--n-steps-before-checkpoint", "50", "--n-steps-before-eval", "50"]
# ---
# # Train an SLM from scratch with early-stopping grid search over hyperparameters
# 
# When you want a language model that performs well on your task, there are three options,
# ordered by the degree of customization:
# - [**Prompt Engineering**](https://en.wikipedia.org/wiki/Prompt_engineering):
# large and capable language models understand tasks in natural language, so you can
# carefully design a natural language "prompt" to elicit the desired behavior.
# - [**Fine-Tuning**](https://modal.com/docs/examples/llm-finetuning):
# those same language models were trained by gradient descent on data sets representing tasks,
# and they can be further trained by gradient descent on data sets representative of your task.
# - **Training from Scratch**:
# if you have enough data for your task, you can throw the pretrained model away and make your own.
# Each step adds additional engineering complexity, but also leads to a superior cost-performance Pareto frontier
# for your tasks. Fine-tuned models at one-tenth the size regularly outperform more generic models,
# and models trained from scratch outperform them.
# Because these models are so much smaller than the Large Language Models that power generic
# assistant chatbots like ChatGPT and Claude, they are often called _Small Language Models_ (SLMs).
# In this example, we will explore training an SLM from scratch on Modal.
# In fact, we'll train 8 SLMs in parallel with different hyperparameters
# and then select the best one for additional training.
# We'll monitor this training live and serve our training and trained models
# as web endpoints and simple browser UIs.
# Along the way we'll use many features of the Modal platform:
# [distributed volumes](https://modal.com/docs/guide/volumes),
# multiple [web endpoints](https://modal.com/docs/guide/webhooks),
# and [parallel container execution](https://modal.com/docs/guide/scale#parallel-execution-of-inputs).
# Together, these features give every machine learning and AI team
# the same infrastructural capabilities that the most sophisticated companies
# have in their internal platforms.
# ## Basic Setup
import logging as L
import urllib.request
from dataclasses import dataclass
from pathlib import Path, PosixPath
import modal
from pydantic import BaseModel
MINUTES = 60  # seconds
HOURS = 60 * MINUTES  # seconds; used for the training function's timeout
app_name = "example-hp-sweep-gpt"  # name passed to modal.App below
app = modal.App(app_name)  # every function, class and endpoint in this file registers on this App
# We'll use A10G GPUs for training, which are able to train the model to recognizably improved performance
# in ~15 minutes while keeping costs under ~$1.
gpu = "A10G"  # GPU type requested by both `train_model` and `ModelInference`
# ### Create a Volume to store data, weights, and logs
# Since we'll be coordinating training across multiple machines we'll use a
# distributed [Volume](https://modal.com/docs/guide/volumes)
# to store the data, checkpointed models, and TensorBoard logs.
volume = modal.Volume.from_name(
    "example-hp-sweep-gpt-volume", create_if_missing=True
)
volume_path = PosixPath("/vol/data")  # where the Volume is mounted inside containers
model_filename = "nano_gpt_model.pt"  # checkpoint file name inside each run's directory
best_model_filename = "best_nano_gpt_model.pt"  # symlink name pointing at the best run's checkpoint
tb_log_path = volume_path / "tb_logs"  # TensorBoard log directory
model_save_path = volume_path / "models"  # checkpoints: <experiment_name>/<model_name>/<model_filename>
# ### Define dependencies in container images
# The container image for training is based on Modal's default slim Debian Linux image with `torch`
# for defining and running our neural network and `tensorboard` for monitoring training.
# Shared base layer: Debian slim, Python 3.11, pinned pydantic. The web and UI images extend it too.
base_image = modal.Image.debian_slim(python_version="3.11").pip_install(
    "pydantic==2.9.1"
)
torch_image = base_image.pip_install(
    "torch==2.1.2",
    "tensorboard==2.17.1",
    "numpy<2",
)
# We also have some local dependencies that we'll need to import into the remote environment.
# We add them into the remote container.
# The local `src` directory next to this file becomes `/root/src` in the container.
torch_image = torch_image.add_local_dir(
    Path(__file__).parent / "src", remote_path="/root/src"
)
# We'll serve a simple web endpoint:
web_image = base_image.pip_install(
    "fastapi[standard]==0.115.4", "starlette==0.41.2"
)
# And we'll deploy a web UI for interacting with our trained models using Gradio.
assets_path = Path(__file__).parent / "assets"  # local static files (favicon, background, CSS)
# The UI image extends the web image with Gradio and ships the assets at `/assets`.
ui_image = web_image.pip_install("gradio~=4.44.0").add_local_dir(
    assets_path, remote_path="/assets"
)
# We can also "pre-import" libraries that will be used by the functions we run on Modal in a given image
# using the `with image.imports` context manager.
with torch_image.imports():
import glob
import os
from timeit import default_timer as timer
import tensorboard
import torch
from src.dataset import Dataset
from src.logs_manager import LogsManager
from src.model import AttentionModel
from src.tokenizer import Tokenizer
# ## Running SLM training on Modal
# Here we define the training function, wrapping it in a decorator
# that specifies the infrastructural parameters, like the container `image` we want to use,
# which `volume` to mount where, the `gpu` we're using, and so on.
# Training consists of specifying optimization parameters, loading the
# `dataset`, building the `model`, setting up TensorBoard logging &
# checkpointing, and then finally executing the `training_loop` itself.
@app.function(
    image=torch_image,
    volumes={volume_path: volume},
    gpu=gpu,
    timeout=1 * HOURS,
)
def train_model(
    node_rank,
    n_nodes,
    hparams,
    experiment_name,
    run_to_first_save=False,
    n_steps=3000,
    n_steps_before_eval=None,
    n_steps_before_checkpoint=None,
):
    """Train one model configuration, or resume it from its checkpoint.

    Args:
        node_rank: Zero-based index of this run; shown in log lines (as
            ``node_rank + 1``) and echoed back in the return value.
        n_nodes: Number of runs launched together. Not used inside this
            function.
        hparams: Hyperparameters for the model. ``hparams.context_size`` is
            read here; the object is also passed to the model, the logs
            manager and stored in the checkpoint.
        experiment_name: Name of the sweep; checkpoints are written under
            ``model_save_path / experiment_name / <model name>``.
        run_to_first_save: If True, stop right after the first checkpoint is
            written. If False and a checkpoint already exists, the run is
            treated as the best model and a "best" symlink is created.
        n_steps: Last training step (inclusive).
        n_steps_before_eval: Steps between evaluations. Defaults to one
            eighth of ``n_steps``.
        n_steps_before_checkpoint: Steps between checkpoints. Defaults to one
            quarter of ``n_steps``.

    Returns:
        Tuple ``(node_rank, validation_loss, hparams)`` where
        ``validation_loss`` is a float taken from the last evaluation.
    """
    # optimizer, data, and model prep
    batch_size = 64
    learning_rate = 3e-4
    n_eval_steps = 100
    # The intervals are used as modulo divisors in the training loop, so they
    # must never be 0 (which `int(n_steps / 8)` yields for very small runs).
    if n_steps_before_eval is None:
        n_steps_before_eval = max(1, int(n_steps / 8))  # eval eight times per run
    if n_steps_before_checkpoint is None:
        n_steps_before_checkpoint = max(
            1, int(n_steps / 4)
        )  # save four times per run

    train_percent = 0.9

    L.basicConfig(
        level=L.INFO,
        format=f"\033[0;32m%(asctime)s %(levelname)s [%(filename)s.%(funcName)s:%(lineno)d] [Node {node_rank + 1}] %(message)s\033[0m",
        datefmt="%b %d %H:%M:%S",
    )

    # use GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    L.info("Remote Device: %s // GPU: %s", device, gpu)

    input_file_path = volume_path / "shakespeare_char.txt"
    text = prepare_data(input_file_path, volume)

    # construct tokenizer & dataset
    tokenizer = Tokenizer(text)
    dataset = Dataset(
        tokenizer.encode(text),
        train_percent,
        batch_size,
        hparams.context_size,
        device,
    )

    # build the model
    model = build_model(hparams, tokenizer.vocab_size, device)
    num_parameters = sum(p.numel() for p in model.parameters())
    L.info(f"Num parameters: {num_parameters}")

    optimizer = setup_optimizer(model, learning_rate)

    # TensorBoard logging & checkpointing prep
    logs_manager = LogsManager(
        experiment_name, hparams, num_parameters, tb_log_path
    )
    L.info(f"Model name: {logs_manager.model_name}")

    model_save_dir = model_save_path / experiment_name / logs_manager.model_name
    checkpoint_path = model_save_dir / model_filename

    # Resume only when the checkpoint file itself exists: the directory can
    # exist without it if an earlier attempt died before its first save.
    if checkpoint_path.exists():
        L.info("Loading model from checkpoint...")
        checkpoint = torch.load(str(checkpoint_path), map_location=device)
        is_best_model = not run_to_first_save
        if is_best_model:
            make_best_symbolic_link(
                model_save_dir, model_filename, experiment_name
            )
        model.load_state_dict(checkpoint["model"])
        if "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
        # `load_state_dict` copies values into the live model/optimizer, so the
        # tensors loaded from disk would stay frozen at their saved values.
        # Re-point the checkpoint at the live state so later saves contain the
        # weights we continue to train, not the ones we started from.
        checkpoint["model"] = model.state_dict()
        checkpoint["optimizer"] = optimizer.state_dict()
        start_step = checkpoint["steps"] + 1
    else:
        model_save_dir.mkdir(parents=True, exist_ok=True)
        start_step = 0
        checkpoint = init_checkpoint(
            model, tokenizer, optimizer, start_step, hparams
        )

    out = training_loop(
        start_step,
        n_steps,
        n_steps_before_eval,
        n_steps_before_checkpoint,
        n_eval_steps,
        dataset,
        tokenizer,
        model,
        optimizer,
        logs_manager,
        checkpoint,
        checkpoint_path,
        run_to_first_save,
    )

    return node_rank, float(out["val"]), hparams
# ## Launch a hyperparameter sweep from a `local_entrypoint`
# The main entry point coordinates the hyperparameter optimization.
# First we specify the default hyperparameters for the model, taken from
# [Andrej Karpathy's walkthrough](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=5976s).
# For better performance, you can increase the `context_size` and scale up the GPU accordingly.
@dataclass
class ModelHyperparameters:
    """Architecture hyperparameters handed to the model and stored in checkpoints.

    The defaults are the baseline configuration; the sweep in ``main`` varies
    ``n_heads``, ``context_size`` and ``dropout`` around them.
    """

    # number of attention heads
    n_heads: int = 6
    # presumably the embedding width -- interpreted by src.model, confirm there
    n_embed: int = 384
    # presumably the number of stacked blocks -- interpreted by src.model, confirm there
    n_blocks: int = 6
    # context length, called the "block size" by Karpathy
    context_size: int = 256
    # dropout rate
    dropout: float = 0.2
# Next we define the local entrypoint: the code we run locally to coordinate training.
# It will train 8 models in parallel across 8 containers, each
# with different hyperparameters, varying the number of heads (`n_heads`), the
# `context_size` (called the "block size" by Karpathy), and the dropout rate (`dropout`). To run in
# parallel we need to use the [`starmap` method](https://modal.com/docs/guide/scale#parallel-execution-of-inputs).
# We train all of the models until the first checkpoint and then stop early so we
# can compare the validation losses.
# Then we restart training for the best model and train it to completion.
# You can kick off training with the following command:
# ```bash
# modal run 06_gpu_and_ml.hyperparameter-sweep.hp_sweep_gpt
# ```
# The output will look something like this:
# ```
# Sep 16 21:20:39 INFO [hp_sweep_gpt.py.train_model:127] [Node 1] Remote Device: cuda // GPU: A10G
# Sep 16 21:20:40 INFO [hp_sweep_gpt.py.train_model:149] [Node 1] Num parameters: 10693697
# Sep 16 21:20:40 INFO [hp_sweep_gpt.py.train_model:156] [Node 1] Model name: E2024-09-16-142031.618259_context_size=8_n_heads=1_dropout=0.1
# Sep 16 21:20:41 INFO [hp_sweep_gpt.py.train_model:225] [Node 1] 0) // 1.03s // Train Loss: 3.58 // Val Loss: 3.60
# Sep 16 21:20:41 INFO [hp_sweep_gpt.py.train_model:127] [Node 2] Remote Device: cuda // GPU: A10G
# ...
# ```
# The `local_entrypoint` code is below. Note that the arguments to it can also be passed via the command line.
# Use `--help` for details.
@app.local_entrypoint()
def main(
    n_steps: int = 3000,
    n_steps_before_checkpoint: int = None,
    n_steps_before_eval: int = None,
):
    """Run the sweep: early-stopped grid search, then finish the best model.

    All hyperparameter combinations are trained in parallel up to their first
    checkpoint. The combination with the lowest validation loss is then
    trained again, this time without stopping early.

    Args:
        n_steps: Last training step for each run.
        n_steps_before_checkpoint: Steps between checkpoints; ``None`` lets
            ``train_model`` pick its default.
        n_steps_before_eval: Steps between evaluations; ``None`` lets
            ``train_model`` pick its default.
    """
    from datetime import datetime
    from itertools import product

    timestamp = datetime.now().strftime("%Y-%m-%d-%H%M%S.%f")
    experiment_name = f"E{timestamp}"
    defaults = ModelHyperparameters()

    # grid of settings to train & validate: two options per swept dimension
    grid = product(
        (1, defaults.n_heads),
        (8, defaults.context_size),
        (0.1, defaults.dropout),
    )
    hparams_list = [
        ModelHyperparameters(n_heads=heads, context_size=ctx, dropout=drop)
        for heads, ctx, drop in grid
    ]

    stop_early = True  # stop early so we can compare val losses
    print(f"Testing {len(hparams_list)} hyperparameter settings")
    n_nodes = len(hparams_list)

    # one positional-argument tuple per remote `train_model` call
    sweep_inputs = [
        (
            rank,
            n_nodes,
            candidate,
            experiment_name,
            stop_early,
            n_steps,
            n_steps_before_eval,
            n_steps_before_checkpoint,
        )
        for rank, candidate in enumerate(hparams_list)
    ]

    # each result is (node_rank, val_loss, hparams); they arrive as runs finish
    results = []
    for result in train_model.starmap(sweep_inputs, order_outputs=False):
        results.append(result)
        finished_rank = result[0]
        print(
            f"[Node {finished_rank + 1}/{n_nodes}] Finished."
            f" Early stop val loss result: {result[1:]}"
        )

    # pick the run with the lowest validation loss
    best_result = min(results, key=lambda entry: entry[1])
    print(f"Best early stop val loss result: {best_result}")
    best_hparams = best_result[-1]

    # finish training with the best hparams on a single node
    train_model.remote(
        0,
        1,
        best_hparams,
        experiment_name,
        not stop_early,
        n_steps,
        n_steps_before_eval,
        n_steps_before_checkpoint,
    )
# ### Monitor experiments with TensorBoard
# To monitor our training we will create a TensorBoard WSGI web app, which will
# display the progress of our training across all 8 models. We'll use the latest
# logs for the most recent experiment written to the Volume.
# To ensure a unique color per experiment you can click the palette (🎨) icon
# under TensorBoard > Time Series > Run and use the Regex:
# `E(\d{4})-(\d{2})-(\d{2})-(\d{6})\.(\d{6})`
# You can deploy this TensorBoard service by running
# ```
# modal deploy 06_gpu_and_ml.hyperparameter-sweep.hp_sweep_gpt
# ```
# and visit it at the URL that ends with `-monitor-training.modal.run`.
# After training finishes, your TensorBoard UI will look something like this:
# 
# You can also find some sample text generated by the model in the "Text" tab.
@app.function(
    image=torch_image,
    volumes={volume_path: volume},
    allow_concurrent_inputs=1000,
)
@modal.wsgi_app()
def monitor_training():
    """Build a TensorBoard WSGI app over the logs stored on the Volume.

    Polls for the log directory, reloading the Volume and sleeping one second
    between checks, and raises if it is still missing after ten retries.

    Returns:
        The TensorBoard WSGI application reading from ``tb_log_path``.
    """
    import time

    print("📈 TensorBoard: Waiting for logs...")
    max_retries = 10
    for attempt in range(max_retries + 1):
        if tb_log_path.exists():
            break
        if attempt == max_retries:
            raise Exception("No logs found after 10 seconds.")
        volume.reload()  # make sure we have the latest data.
        time.sleep(1)

    # start TensorBoard server looking at all experiments
    board = tensorboard.program.TensorBoard()
    board.configure(logdir=str(tb_log_path))
    data_provider, deprecated_multiplexer = board._make_data_provider()
    return tensorboard.backend.application.TensorBoardWSGIApp(
        board.flags,
        board.plugin_loaders,
        data_provider,
        board.assets_zip_provider,
        deprecated_multiplexer,
    )
# Notice that there are 8 models training, and the one with the lowest
# validation loss at the first checkpoint (step 750 with the default settings) continues training to 3000 steps.
# ## Serving SLMs on Modal during and after training
# Because our weights are stored in a distributed Volume,
# we can deploy an inference endpoint based off of them without any extra work --
# and we can even check in on models while we're still training them!
# ### Remote inference with Modal `Cls`es
# We wrap our inference in a Modal `Cls` called `ModelInference`.
# The user of `ModelInference` can control which model is used by providing the
# `experiment_name`. Each unique choice creates a separate
# [auto-scaling deployment](https://modal.com/docs/guide/parameterized-functions).
# If the user does not specify an `experiment_name`, the latest experiment
# is used.
@app.cls(image=torch_image, volumes={volume_path: volume}, gpu=gpu)
class ModelInference:
    """Generate text from the best checkpoint of an experiment.

    ``experiment_name`` selects a directory under ``model_save_path``. With
    the empty default, the most recently created experiment that has a
    best-model checkpoint is used.
    """

    experiment_name: str = modal.parameter(default="")

    def get_latest_available_model_dirs(self, n_last):
        """Find the latest models that have a best model checkpoint saved.

        Args:
            n_last: Stop once this many directories have been collected.

        Returns:
            List of ``Path`` objects, newest first (by ``os.path.getctime``),
            restricted to experiment directories containing
            ``best_model_filename``.
        """
        save_model_dirs = glob.glob(f"{model_save_path}/*")
        sorted_model_dirs = sorted(
            save_model_dirs, key=os.path.getctime, reverse=True
        )

        valid_model_dirs = []
        for latest_model_dir in sorted_model_dirs:
            if Path(f"{latest_model_dir}/{best_model_filename}").exists():
                valid_model_dirs.append(Path(latest_model_dir))
            if len(valid_model_dirs) >= n_last:
                return valid_model_dirs
        return valid_model_dirs

    @modal.method()
    def get_latest_available_experiment_names(self, n_last):
        """Return the names of up to ``n_last`` experiments that can be served."""
        # Pick up experiments committed by other containers since this one started.
        volume.reload()
        return [d.name for d in self.get_latest_available_model_dirs(n_last)]

    def load_model_impl(self):
        """Load the selected (or latest) best checkpoint, unless already current.

        A checkpoint marked ``finished_training`` is loaded only once.
        Otherwise the checkpoint is read again so that progress of a run that
        is still training becomes visible.

        Raises:
            ValueError: If no experiment with a best-model checkpoint exists.
        """
        # Same import style as the module-level `torch_image.imports()` block;
        # a relative import (`.src`) only works when this file is part of a package.
        from src.model import AttentionModel
        from src.tokenizer import Tokenizer

        user_selected = self.experiment_name != ""
        if user_selected:
            use_model_dir = Path(f"{model_save_path}/{self.experiment_name}")
            if self.use_model_dir == use_model_dir and self.is_fully_trained:
                return  # already loaded fully trained model.

        # Changes committed by training containers are only visible after a reload.
        volume.reload()

        if not user_selected:  # otherwise, pick latest
            try:
                use_model_dir = self.get_latest_available_model_dirs(1)[0]
            except IndexError as err:
                raise ValueError("No models available to load.") from err
            if self.use_model_dir == use_model_dir and self.is_fully_trained:
                return  # already loaded fully trained model.

        print(f"Loading experiment: {Path(use_model_dir).name}...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        checkpoint = torch.load(
            f"{use_model_dir}/{best_model_filename}", map_location=self.device
        )

        self.use_model_dir = use_model_dir
        hparams = checkpoint["hparams"]
        key = (  # for backwards compatibility
            "unique_chars" if "unique_chars" in checkpoint else "chars"
        )
        unique_chars = checkpoint[key]
        steps = checkpoint["steps"]
        val_loss = checkpoint["val_loss"]
        self.is_fully_trained = checkpoint["finished_training"]

        print(
            f"Loaded model with {steps} train steps"
            f" and val loss of {val_loss:.2f}"
            f" (fully_trained={self.is_fully_trained})"
        )

        self.tokenizer = Tokenizer(unique_chars)
        self.model = AttentionModel(
            self.tokenizer.vocab_size, hparams, self.device
        )
        self.model.load_state_dict(checkpoint["model"])
        self.model.to(self.device)
        # Modules start in training mode; switch to eval mode so layers such
        # as dropout are disabled while serving.
        self.model.eval()

    @modal.enter()
    def load_model(self):
        """Initialise the bookkeeping attributes and load a model at container start."""
        self.use_model_dir = None
        self.is_fully_trained = False
        self.load_model_impl()

    @modal.method()
    def generate(self, prompt):
        """Generate 1000 new tokens continuing ``prompt``."""
        self.load_model_impl()  # load updated model if available

        n_new_tokens = 1000
        with torch.no_grad():  # inference only, no autograd bookkeeping needed
            return self.model.generate_from_text(
                self.tokenizer, prompt, n_new_tokens
            )
# ### Adding a simple `web_endpoint`
# The `ModelInference` class above is available for use
# from any other Python environment with the right Modal credentials
# and the `modal` package installed -- just use [`lookup`](https://modal.com/docs/reference/modal.Cls#lookup).
# But we can also expose it as a web endpoint for easy access
# from anywhere, including other programming languages or the command line.
class GenerationRequest(BaseModel):
    """JSON body accepted by the `web_generate` endpoint."""

    # text the model should continue
    prompt: str
@app.function(image=web_image)
@modal.web_endpoint(method="POST", docs=True)
def web_generate(request: GenerationRequest, experiment_name: str = ""):
    """Generate text for ``request.prompt`` via the remote `ModelInference` class.

    Args:
        request: JSON body carrying the prompt.
        experiment_name: Optional query parameter naming the experiment whose
            model should be used. The empty default lets `ModelInference` pick
            the latest available one, which is what happened before this
            parameter existed.

    Returns:
        A dict of the form ``{"output": <generated text>}``.
    """
    inference = ModelInference(experiment_name=experiment_name)
    output = inference.generate.remote(request.prompt)
    return {"output": output}
# This endpoint can be deployed on Modal with `modal deploy`.
# That will allow us to generate text via a simple `curl` command like this:
# ```bash
# curl -X POST -H 'Content-Type: application/json' --data-binary '{"prompt": "\n"}' https://your-workspace-name--example-hp-sweep-gpt-web-generate.modal.run
# ```
# which will return something like:
# ```json
# {
# "output":
# "BRUTUS:
# The broy trefore anny pleasory to
# wip me state of villoor so:
# Fortols listhey for brother beat the else
# Be all, ill of lo-love in igham;
# Ah, here all that queen and hould you father offer"
# }
# ```
# It's not exactly Shakespeare, but at least it shows our model learned something!
# You can choose which model to use by specifying the `experiment_name` in the query parameters of the request URL.
# ### Serving a Gradio UI with `asgi_app`
# Second, we create a Gradio web app for generating text via a graphical user interface in the browser.
# That way our fellow team members and stakeholders can easily interact with the model and give feedback,
# even when we're still training the model.
# You should see the URL for this UI in the output of `modal deploy`
# or on your [Modal app dashboard](https://modal.com/apps) for this app.
# The Gradio UI will look something like this:
# 
@app.function(
    image=ui_image,
    concurrency_limit=1,
    volumes={volume_path: volume},
    allow_concurrent_inputs=1000,
)
@modal.asgi_app()
def ui():
    """Build the Gradio text-generation UI and mount it on a FastAPI app.

    The list of selectable experiments is fetched once, when this function
    runs, from the remote `ModelInference` class. Generation itself is also
    delegated to `ModelInference`.

    Returns:
        The FastAPI application with the Gradio interface mounted at ``/``.
    """
    import gradio as gr
    from fastapi import FastAPI
    from fastapi.responses import FileResponse
    from gradio.routes import mount_gradio_app

    # call out to the inference in a separate Modal environment with a GPU
    def generate(text="", experiment_name=""):
        if not text:
            text = "\n"
        # The dropdown may hand us None when nothing has been selected, but the
        # ModelInference parameter is a string; "" means "use the latest model".
        generated = ModelInference(
            experiment_name=experiment_name or ""
        ).generate.remote(text)
        return text + generated

    example_prompts = [
        "DUKE OF YORK:\nWhere art thou Lucas?",
        "ROMEO:\nWhat is a man?",
        "CLARENCE:\nFair is foul and foul is fair, but who are you?",
        "Brevity is the soul of wit, so what is the soul of foolishness?",
    ]

    web_app = FastAPI()

    # custom styles: an icon, a background, and a theme
    @web_app.get("/favicon.ico", include_in_schema=False)
    async def favicon():
        return FileResponse("/assets/favicon.svg")

    @web_app.get("/assets/background.svg", include_in_schema=False)
    async def background():
        return FileResponse("/assets/background.svg")

    with open("/assets/index.css", encoding="utf-8") as f:
        css = f.read()

    n_last = 20
    experiment_names = (
        ModelInference().get_latest_available_experiment_names.remote(n_last)
    )
    theme = gr.themes.Default(
        primary_hue="green", secondary_hue="emerald", neutral_hue="neutral"
    )

    # add a Gradio UI around inference
    with gr.Blocks(theme=theme, css=css, title="SLM") as interface:
        # title
        gr.Markdown("# GPT-style Shakespeare text generation.")

        # Model Selection
        with gr.Row():
            gr.Markdown("## Model Version")
        with gr.Row():
            experiment_dropdown = gr.Dropdown(
                experiment_names, label="Select Model Version"
            )

        # input and output
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Input:")
                input_box = gr.Textbox(  # input text component
                    label="",
                    placeholder="Write some Shakespeare like text or keep it empty!",
                    lines=10,
                )
            with gr.Column():
                gr.Markdown("## Output:")
                output_box = gr.Textbox(  # output text component
                    label="",
                    lines=10,
                )

        # button to trigger inference and a link to Modal
        with gr.Row():
            generate_button = gr.Button("Generate", variant="primary", scale=2)
            generate_button.click(
                fn=generate,
                inputs=[input_box, experiment_dropdown],
                outputs=output_box,
            )  # connect inputs and outputs with inference function

            gr.Button(  # shameless plug
                " Powered by Modal",
                variant="secondary",
                link="https://modal.com",
            )

        # example prompts
        with gr.Column(variant="compact"):
            # add in a few examples to inspire users
            for ii, prompt in enumerate(example_prompts):
                btn = gr.Button(prompt, variant="secondary")
                # bind the index as a default so each button keeps its own prompt
                btn.click(
                    fn=lambda idx=ii: example_prompts[idx], outputs=input_box
                )

    # mount for execution on Modal
    return mount_gradio_app(
        app=web_app,
        blocks=interface,
        path="/",
    )
# ## Addenda
# The remainder of this code is boilerplate.
# ### Training Loop
# There's quite a lot of code for just the training loop! If you'd rather not write this stuff yourself,
# consider a training framework like [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable)
# or [Hugging Face](https://huggingface.co/transformers/main_classes/trainer.html).
def training_loop(
    start_step,
    n_steps,
    n_steps_before_eval,
    n_steps_before_checkpoint,
    n_eval_steps,
    dataset,
    tokenizer,
    model,
    optimizer,
    logs_manager,
    checkpoint,
    checkpoint_path,
    run_to_first_save,
):
    """Run optimization steps with periodic evaluation and checkpointing.

    Args:
        start_step: First step to run (0 for a fresh run).
        n_steps: Last step to run (inclusive).
        n_steps_before_eval: Evaluate whenever ``step`` is a multiple of this.
        n_steps_before_checkpoint: Save whenever ``step`` is a positive
            multiple of this.
        n_eval_steps: Number of batches averaged per split in an evaluation.
        dataset: Provides ``get_batch(split)`` for "train" and "val".
        tokenizer: Passed to the model for generating the text sample.
        model: The model being trained.
        optimizer: Optimizer updating ``model``.
        logs_manager: Receives training/validation scalars and text samples.
        checkpoint: Dict updated in place and written at each save.
        checkpoint_path: Destination file for checkpoints.
        run_to_first_save: If True, stop right after the first save.

    Returns:
        The result dict of the most recent evaluation, with keys "train",
        "val" and "sample".
    """

    @torch.no_grad()
    def eval_model(model, dataset, tokenizer, n_eval_steps):
        """Evaluate model on train and validation data."""
        out = {}
        model.eval()  # evaluation mode; gradients are disabled by the decorator
        for split in ("train", "val"):
            losses = torch.zeros(n_eval_steps)
            for k in range(n_eval_steps):
                xb, yb = dataset.get_batch(split)
                logits, loss = model.forward(xb, yb)
                losses[k] = loss
            out[split] = losses.mean()

        # Generate some output samples
        out["sample"] = model.generate_from_text(tokenizer, "\n", 1000)

        model.train()  # back to training mode
        return out

    # Both intervals are modulo divisors below; guard against zero.
    n_steps_before_eval = max(1, n_steps_before_eval)
    n_steps_before_checkpoint = max(1, n_steps_before_checkpoint)

    # `out` stays None until the first evaluation of this call; a resumed run
    # may reach a checkpoint step (or finish) before any evaluation step.
    out = None
    t_last = timer()
    for step in range(start_step, n_steps + 1):
        # sample a batch of data
        xb, yb = dataset.get_batch("train")

        # evaluate the loss, calculate & apply gradients
        logits, loss = model.forward(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # log training loss
        logs_manager.add_train_scalar("Cross Entropy Loss", loss.item(), step)

        # evaluate model on validation set
        if step % n_steps_before_eval == 0:
            out = eval_model(model, dataset, tokenizer, n_eval_steps)
            log_evals(out, step, t_last, logs_manager)
            t_last = timer()

        # save model with checkpoint information
        if step > 0 and step % n_steps_before_checkpoint == 0:
            if out is None:
                # no evaluation yet in this call: run one so the checkpoint
                # has a validation loss to record
                out = eval_model(model, dataset, tokenizer, n_eval_steps)
                log_evals(out, step, t_last, logs_manager)
                t_last = timer()

            # Take the state dicts at save time so the file always holds the
            # current weights and the optimizer state accumulated so far.
            checkpoint["model"] = model.state_dict()
            checkpoint["optimizer"] = optimizer.state_dict()
            checkpoint["steps"] = step
            checkpoint["val_loss"] = out["val"]

            # mark as finished if we hit n steps.
            checkpoint["finished_training"] = step >= n_steps

            L.info(
                f"Saving checkpoint to {checkpoint_path}"
                f"\t (finished_training={checkpoint['finished_training']})"
            )
            save_checkpoint(checkpoint, checkpoint_path)

            if run_to_first_save:
                L.info("Stopping early...")
                break

    if out is None:
        # nothing was evaluated (e.g. no steps left to run); still give the
        # caller a result containing a validation loss
        out = eval_model(model, dataset, tokenizer, n_eval_steps)
    return out
def save_checkpoint(checkpoint, checkpoint_path):
    """Serialize ``checkpoint`` to ``checkpoint_path`` and commit the Volume."""
    destination = checkpoint_path
    torch.save(obj=checkpoint, f=destination)
    # commit so the file written above is persisted to the Volume
    volume.commit()
def build_model(hparams, vocab_size, device):
    """Create an `AttentionModel` for the given vocabulary and place it on ``device``."""
    attention_model = AttentionModel(vocab_size, hparams, device)
    # `.to` is called for its effect on the module; the same object is returned
    attention_model.to(device)
    return attention_model
def setup_optimizer(model, learning_rate):
    """Create an AdamW optimizer over all parameters of ``model``."""
    trainable = model.parameters()
    adamw = torch.optim.AdamW(trainable, lr=learning_rate)
    return adamw
# ### Miscellaneous
# The remaining code includes small helper functions for training the model.
def prepare_data(input_file_path: Path, volume: modal.Volume) -> str:
    """Download the dataset into the Volume if needed and return its text.

    Args:
        input_file_path: Location of the dataset file on the mounted Volume.
        volume: The Volume holding ``input_file_path``. It is reloaded before
            the existence check and committed after a download.

    Returns:
        The full contents of the dataset file.

    Raises:
        Exception: Whatever the download raised. Any partially written file
            is removed first.
    """
    volume.reload()
    if not input_file_path.exists():
        L.info("Downloading Shakespeare dataset...")
        data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        try:
            urllib.request.urlretrieve(data_url, input_file_path)
        except Exception:
            # A truncated file would pass the `exists()` check next time and
            # silently be used as the full dataset, so remove it.
            input_file_path.unlink(missing_ok=True)
            raise
        volume.commit()
    # explicit encoding so the result does not depend on the container locale
    return input_file_path.read_text(encoding="utf-8")
def make_best_symbolic_link(model_save_dir, model_filename, experiment_name):
    """Point the experiment's "best model" symlink at this run's checkpoint.

    Args:
        model_save_dir: Directory of the run whose checkpoint is the best one.
        model_filename: Checkpoint file name inside ``model_save_dir``.
        experiment_name: Experiment directory (under ``model_save_path``) in
            which the symlink named ``best_model_filename`` is created.
    """
    # create symlink to the best model so it's easy to find for web serving
    target = model_save_dir / model_filename
    link = model_save_path / experiment_name / best_model_filename
    # `os.symlink` raises FileExistsError if the link is already there, e.g.
    # when the final training run is retried. `is_symlink` also catches a
    # dangling link, for which `exists()` is False.
    if link.is_symlink() or link.exists():
        link.unlink()
    os.symlink(str(target), str(link))
    volume.commit()  # commit the symlink
def init_checkpoint(model, tokenizer, optimizer, start_step, hparams):
    """Assemble the checkpoint dict for a run that has not been saved yet.

    The validation loss starts at infinity and the run is marked as not
    finished; the training loop updates these entries when it saves.
    """
    checkpoint = {}
    checkpoint["model"] = model.state_dict()
    checkpoint["unique_chars"] = tokenizer.unique_chars
    checkpoint["optimizer"] = optimizer.state_dict()
    checkpoint["val_loss"] = float("inf")
    checkpoint["steps"] = start_step
    checkpoint["hparams"] = hparams
    checkpoint["finished_training"] = False
    return checkpoint
def log_evals(result, step, t_last, logs_manager):
    """Report one evaluation to the log, to TensorBoard and to the Volume.

    Args:
        result: Evaluation dict with "train", "val" and "sample" entries.
        step: Training step the evaluation belongs to.
        t_last: Timer reading taken at the previous evaluation; the elapsed
            time since then is included in the log line.
        logs_manager: Receives the validation scalar and the text sample.

    Returns:
        ``result``, unchanged.
    """
    elapsed_s = timer() - t_last
    train_loss = result["train"]
    val_loss = result["val"]
    L.info(
        f"{step:5d}) // {elapsed_s:>5.2f}s"
        f" // Train Loss: {train_loss:.2f} // Val Loss:"
        f" {val_loss:.2f}"
    )

    logs_manager.add_val_scalar("Cross Entropy Loss", val_loss, step)
    logs_manager.add_val_text("Sample Output", result["sample"], step)
    logs_manager.flush()
    volume.commit()  # Make sure TensorBoard container will see it.

    return result