
Commit ff6dd19

Author: Hossein Kavianihamedani committed

Adding references
1 parent 815a44f commit ff6dd19

File tree

1 file changed (+90, -63 lines)

apps/sft/interactive_config_notebook.ipynb

Lines changed: 90 additions & 63 deletions
@@ -12,7 +12,7 @@
  "\n",
  "This notebook tells the complete story of how SFT training works:\n",
  "\n",
- "1. **🎭 The Actor Model** - Understanding TrainerActor\n",
+ "1. **🎭 The Actor Model** - Understanding TrainerActor (built on Monarch)\n",
  "2. **🔧 Setup Phase** - Loading models, data, and checkpoints\n",
  "3. **🏃 Training Loop** - Forward passes, backprop, optimization\n",
  "5. **🧹 Cleanup** - Saving checkpoints and releasing resources\n",
@@ -21,14 +21,28 @@
  "\n",
  "## The Forge Actor Architecture\n",
  "\n",
+ "### What is Monarch?\n",
+ "\n",
+ "**Monarch** is Meta's distributed actor framework that powers Forge:\n",
+ "- 🌐 **Distributed by design** - Built for multi-node, multi-GPU training\n",
+ "- 🎭 **Actor model** - Encapsulates distributed processes as actors\n",
+ "- 📡 **Remote communication** - Seamless RPC between actors\n",
+ "- 🔧 **Lifecycle management** - Spawn → Setup → Run → Cleanup pattern\n",
+ "\n",
+ "Forge leverages Monarch to abstract away distributed training complexity!\n",
+ "\n",
+ "For more information on Monarch, visit https://github.com/meta-pytorch/monarch/tree/main/docs\n",
+ "\n",
  "### What is a TrainerActor?\n",
  "\n",
- "Think of a **TrainerActor** as the conductor of an orchestra:\n",
+ "A **TrainerActor** is Forge's Monarch actor for training:\n",
  "- 🎭 **Manages multiple processes** across GPUs or nodes\n",
- "- 🔧 **Controls the lifecycle** of training (setup → train → cleanup)\n",
+ "- 🔧 **Controls the lifecycle** using Monarch's actor pattern\n",
  "- 📊 **Coordinates distributed training** with FSDP, tensor parallelism, etc.\n",
  "\n",
- "### The Training Journey\n",
+ "Think of it as the conductor of an orchestra - coordinating 8 GPU processes working together!\n",
+ "\n",
+ "### The Training Journey (Monarch Actor Lifecycle)\n",
  "\n",
  "```\n",
  "┌─────────────────────────────────────────┐\n",
@@ -37,19 +51,19 @@
  "└──────────────┬──────────────────────────┘\n",
  "\n",
  "┌─────────────────────────────────────────┐\n",
- "│ 2. Spawn Actor 🎭 │ ← Forge creates distributed processes\n",
+ "│ 2. Spawn Actor 🎭 [MONARCH] │ ← Monarch creates distributed processes\n",
  "│ (launch 8 GPU processes) │\n",
  "└──────────────┬──────────────────────────┘\n",
  "\n",
  "┌─────────────────────────────────────────┐\n",
- "│ 3. Setup Phase 🔧 │ ← Load model, data, checkpoints\n",
+ "│ 3. Setup Phase 🔧 [MONARCH] │ ← Actor.setup() endpoint\n",
  "│ - Initialize model with FSDP │\n",
- "│ - Load training dataset │\n",
+ "│ - Load training dataset │\n",
  "│ - Restore from checkpoint (if any) │\n",
  "└──────────────┬──────────────────────────┘\n",
  "\n",
  "┌─────────────────────────────────────────┐\n",
- "│ 4. Training Loop 🔄 │ ← The main training process\n",
+ "│ 4. Training Loop 🔄 [MONARCH] │ ← Actor.train() endpoint\n",
  "│ FOR each step: │\n",
  "│ → Get batch from dataloader │\n",
  "│ → Forward pass (compute loss) │\n",
@@ -60,7 +74,7 @@
  "└──────────────┬──────────────────────────┘\n",
  "\n",
  "┌─────────────────────────────────────────┐\n",
- "│ 5. Cleanup Phase 🧹 │ ← Save final state\n",
+ "│ 5. Cleanup Phase 🧹 [MONARCH] │ ← Actor.cleanup() endpoint\n",
  "│ - Save final checkpoint │\n",
  "│ - Release GPU memory │\n",
  "│ - Stop all processes │\n",
@@ -69,11 +83,13 @@
  "\n",
  "### Why This Architecture?\n",
  "\n",
- "✅ **Automatic Distribution** - Forge handles multi-GPU/multi-node complexity \n",
+ "✅ **Automatic Distribution** - Monarch handles multi-GPU/multi-node complexity \n",
  "✅ **Fault Tolerance** - Checkpointing enables recovery from failures \n",
  "✅ **Flexibility** - Easy to switch between 1 GPU, 8 GPUs, or multiple nodes \n",
- "✅ **Production-Ready** - Used at Meta for large-scale training\n",
+ "✅ **Production-Ready** - Used at Meta for large-scale training \n",
+ "✅ **Actor Pattern** - Clean separation of concerns with lifecycle methods\n",
  "\n",
+ "#### For more information regarding Forge visit: https://github.com/meta-pytorch/torchforge/tree/main/docs\n",
  "---\n",
  "\n",
  "Let's configure your training!"
@@ -502,23 +518,32 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
- "### Phase 2: Setup 🔧\n",
+ "### Phase 2: Setup 🔧 [Monarch Endpoint]\n",
  "\n",
  "**What's happening:**\n",
- "- **Model Loading**: Each process loads its shard of the model\n",
- " - With FSDP, GPU 0 might get layers 0-10\n",
+ "\n",
+ "Monarch calls the `@endpoint` decorated `setup()` method on all 8 actor instances:\n",
+ "\n",
+ "```python\n",
+ "class TrainerActor:\n",
+ " @endpoint\n",
+ " async def setup(self):\n",
+ " # This runs on all 8 GPUs simultaneously\n",
+ " ...\n",
+ "```\n",
+ "\n",
+ "Each actor instance:\n",
+ "- **Loads its shard of the model**: With FSDP, each GPU only loads ~1/8th\n",
+ " - GPU 0 might get layers 0-10\n",
  " - GPU 1 gets layers 11-20, etc.\n",
- " - Each GPU only holds ~1/8th of the full model\n",
- "- **Dataset Loading**: Training and validation dataloaders created\n",
- " - Same dataset, but different random seeds per GPU\n",
- " - Ensures each GPU sees different data\n",
- "- **Checkpoint Loading**: If resuming, restore training state\n",
- " - Model weights, optimizer state, current step number\n",
+ "- **Creates dataloaders**: Same dataset, different random seeds per GPU\n",
+ "- **Restores checkpoint**: If resuming, loads saved state\n",
  "\n",
  "**What `setup()` does internally:**\n",
  "```python\n",
- "def setup(self):\n",
- " # 1. Initialize model with FSDP\n",
+ "@endpoint\n",
+ "async def setup(self):\n",
+ " # 1. Initialize model with FSDP sharding\n",
  " self.model = load_model_with_fsdp(cfg.model)\n",
  " \n",
  " # 2. Create training dataloader\n",
@@ -537,18 +562,18 @@
  " self.checkpointer.load(step=self.current_step)\n",
  "```\n",
  "\n",
- "After setup, all 8 GPUs are synchronized and ready to train!"
+ "**Monarch magic:**\n",
+ "- The `@endpoint` decorator makes this method callable remotely\n",
+ "- Monarch ensures all 8 actors complete setup before proceeding\n",
+ "- Distributed state (model shards) automatically synchronized\n",
+ "\n",
+ "After setup, all 8 GPU actors are synchronized and ready to train!"
  ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
- "metadata": {
- "output": {
- "id": 693658349895675,
- "loadingStatus": "loaded"
- }
- },
+ "metadata": {},
  "outputs": [],
  "source": [
  "# Setup (load data, checkpoints, etc.)\n",
@@ -560,38 +585,40 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
- "### Phase 3: Training Loop 🔄\n",
+ "### Phase 3: Training Loop 🔄 [Monarch Endpoint]\n",
  "\n",
  "**What's happening:**\n",
  "\n",
- "The training loop runs for `cfg.training.steps` iterations. Each step:\n",
+ "Monarch calls the `@endpoint` decorated `train()` method, which runs the training loop for `cfg.training.steps` iterations. Each step:\n",
  "\n",
  "```python\n",
- "for step in range(current_step, max_steps):\n",
- " # 1. Get next batch from dataloader\n",
- " batch = next(train_dataloader)\n",
- " # Shape: [batch_size, seq_len] per GPU\n",
- " \n",
- " # 2. Forward pass - compute predictions and loss\n",
- " outputs = model(batch['input_ids'])\n",
- " loss = compute_loss(outputs, batch['labels'])\n",
- " \n",
- " # 3. Backward pass - compute gradients\n",
- " loss.backward()\n",
- " # FSDP automatically synchronizes gradients across all GPUs!\n",
- " \n",
- " # 4. Optimizer step - update model weights\n",
- " optimizer.step()\n",
- " optimizer.zero_grad()\n",
- " \n",
- " # 5. Periodic validation (if enabled)\n",
- " if validation_enabled and step % eval_interval == 0:\n",
- " val_metrics = evaluate()\n",
- " log(f\"Step {step}: Val Loss = {val_metrics['val_loss']}\")\n",
- " \n",
- " # 6. Periodic checkpointing\n",
- " if step % checkpoint_interval == 0:\n",
- " save_checkpoint(step)\n",
+ "@endpoint\n",
+ "async def train(self):\n",
+ " for step in range(current_step, max_steps):\n",
+ " # 1. Get next batch from dataloader\n",
+ " batch = next(train_dataloader)\n",
+ " # Shape: [batch_size, seq_len] per GPU\n",
+ "\n",
+ " # 2. Forward pass - compute predictions and loss\n",
+ " outputs = model(batch['input_ids'])\n",
+ " loss = compute_loss(outputs, batch['labels'])\n",
+ "\n",
+ " # 3. Backward pass - compute gradients\n",
+ " loss.backward()\n",
+ " # FSDP automatically synchronizes gradients across all GPUs!\n",
+ "\n",
+ " # 4. Optimizer step - update model weights\n",
+ " optimizer.step()\n",
+ " optimizer.zero_grad()\n",
+ "\n",
+ " # 5. Periodic validation (if enabled)\n",
+ " if validation_enabled and step % eval_interval == 0:\n",
+ " val_metrics = evaluate()\n",
+ " log(f\"Step {step}: Val Loss = {val_metrics['val_loss']}\")\n",
+ "\n",
+ " # 6. Periodic checkpointing\n",
+ " if step % checkpoint_interval == 0:\n",
+ " save_checkpoint(step)\n",
  "```\n",
  "\n",
  "**Key insights:**\n",
@@ -604,18 +631,18 @@
  "- Training loss decreasing over time\n",
  "- Periodic validation metrics (if enabled)\n",
  "- Checkpoint saves at regular intervals\n",
- "- Step timing information (seconds per step)"
+ "- Step timing information (seconds per step)\n",
+ "\n",
+ "**Monarch magic:**\n",
+ "- The `@endpoint` decorator makes this long-running training loop remotely callable\n",
+ "- All 8 actor instances run training in sync\n",
+ "- Monarch handles any RPC timeouts for long-running operations"
  ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
- "metadata": {
- "output": {
- "id": 4257826794454822,
- "loadingStatus": "loaded"
- }
- },
+ "metadata": {},
  "outputs": [],
  "source": [
  "# Run training\n",
