|
12 | 12 | "\n", |
13 | 13 | "This notebook tells the complete story of how SFT training works:\n", |
14 | 14 | "\n", |
15 | | - "1. **🎭 The Actor Model** - Understanding TrainerActor\n", |
| 15 | + "1. **🎭 The Actor Model** - Understanding TrainerActor (built on Monarch)\n", |
16 | 16 | "2. **🔧 Setup Phase** - Loading models, data, and checkpoints\n", |
17 | 17 | "3. **🏃 Training Loop** - Forward passes, backprop, optimization\n", |
18 | 18 | "5. **🧹 Cleanup** - Saving checkpoints and releasing resources\n", |
|
21 | 21 | "\n", |
22 | 22 | "## The Forge Actor Architecture\n", |
23 | 23 | "\n", |
| 24 | + "### What is Monarch?\n", |
| 25 | + "\n", |
| 26 | + "**Monarch** is Meta's distributed actor framework that powers Forge:\n", |
| 27 | + "- 🌐 **Distributed by design** - Built for multi-node, multi-GPU training\n", |
| 28 | + "- 🎭 **Actor model** - Encapsulates distributed processes as actors\n", |
| 29 | + "- 📡 **Remote communication** - Seamless RPC between actors\n", |
| 30 | + "- 🔧 **Lifecycle management** - Spawn → Setup → Run → Cleanup pattern\n", |
| 31 | + "\n", |
| 32 | + "Forge leverages Monarch to abstract away distributed training complexity!\n", |
| 33 | + "\n", |
| 34 | + "For more information on Monarch, visit https://github.com/meta-pytorch/monarch/tree/main/docs\n", |
| 35 | + "\n", |
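| | + "To make the actor pattern concrete, here is a minimal sketch modeled on the Counter example in Monarch's docs (the exact API - `proc_mesh`, `spawn`, `.call()` - is an assumption and may vary across Monarch versions):\n", |
| | + "\n", |
| | + "```python\n", |
| | + "from monarch.actor import Actor, endpoint, proc_mesh\n", |
| | + "\n", |
| | + "class Counter(Actor):\n", |
| | + "    def __init__(self, start: int):\n", |
| | + "        self.value = start\n", |
| | + "\n", |
| | + "    @endpoint\n", |
| | + "    async def incr(self) -> None:\n", |
| | + "        self.value += 1\n", |
| | + "\n", |
| | + "# Inside an async main(): spawn one actor per process, then fan out a call\n", |
| | + "procs = await proc_mesh(gpus=8)                       # 8 processes, 1 per GPU\n", |
| | + "counters = await procs.spawn(\"counters\", Counter, 0)\n", |
| | + "await counters.incr.call()                            # runs on every instance\n", |
| | + "```\n", |
| | + "\n", |
| | + "`TrainerActor` follows this same pattern, with `setup()`, `train()`, and `cleanup()` as its endpoints.\n", |
| | + "\n", |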
24 | 36 | "### What is a TrainerActor?\n", |
25 | 37 | "\n", |
26 | | - "Think of a **TrainerActor** as the conductor of an orchestra:\n", |
| 38 | + "A **TrainerActor** is Forge's Monarch actor for training:\n", |
27 | 39 | "- 🎭 **Manages multiple processes** across GPUs or nodes\n", |
28 | | - "- 🔧 **Controls the lifecycle** of training (setup → train → cleanup)\n", |
| 40 | + "- 🔧 **Controls the lifecycle** using Monarch's actor pattern\n", |
29 | 41 | "- 📊 **Coordinates distributed training** with FSDP, tensor parallelism, etc.\n", |
30 | 42 | "\n", |
31 | | - "### The Training Journey\n", |
| 43 | + "Think of it as the conductor of an orchestra - coordinating 8 GPU processes working together!\n", |
| 44 | + "\n", |
| 45 | + "### The Training Journey (Monarch Actor Lifecycle)\n", |
32 | 46 | "\n", |
33 | 47 | "```\n", |
34 | 48 | "┌─────────────────────────────────────────┐\n", |
|
37 | 51 | "└──────────────┬──────────────────────────┘\n", |
38 | 52 | " ↓\n", |
39 | 53 | "┌─────────────────────────────────────────┐\n", |
40 | | - "│ 2. Spawn Actor 🎭 │ ← Forge creates distributed processes\n", |
| 54 | + "│ 2. Spawn Actor 🎭 [MONARCH] │ ← Monarch creates distributed processes\n", |
41 | 55 | "│ (launch 8 GPU processes) │\n", |
42 | 56 | "└──────────────┬──────────────────────────┘\n", |
43 | 57 | " ↓\n", |
44 | 58 | "┌─────────────────────────────────────────┐\n", |
45 | | - "│ 3. Setup Phase 🔧 │ ← Load model, data, checkpoints\n", |
| 59 | + "│ 3. Setup Phase 🔧 [MONARCH] │ ← Actor.setup() endpoint\n", |
46 | 60 | "│ - Initialize model with FSDP │\n", |
47 | | - "│ - Load training dataset │ │\n", |
| 61 | + "│ - Load training dataset │\n", |
48 | 62 | "│ - Restore from checkpoint (if any) │\n", |
49 | 63 | "└──────────────┬──────────────────────────┘\n", |
50 | 64 | " ↓\n", |
51 | 65 | "┌─────────────────────────────────────────┐\n", |
52 | | - "│ 4. Training Loop 🔄 │ ← The main training process\n", |
| 66 | + "│ 4. Training Loop 🔄 [MONARCH] │ ← Actor.train() endpoint\n", |
53 | 67 | "│ FOR each step: │\n", |
54 | 68 | "│ → Get batch from dataloader │\n", |
55 | 69 | "│ → Forward pass (compute loss) │\n", |
|
60 | 74 | "└──────────────┬──────────────────────────┘\n", |
61 | 75 | " ↓\n", |
62 | 76 | "┌─────────────────────────────────────────┐\n", |
63 | | - "│ 5. Cleanup Phase 🧹 │ ← Save final state\n", |
| 77 | + "│ 5. Cleanup Phase 🧹 [MONARCH] │ ← Actor.cleanup() endpoint\n", |
64 | 78 | "│ - Save final checkpoint │\n", |
65 | 79 | "│ - Release GPU memory │\n", |
66 | 80 | "│ - Stop all processes │\n", |
|
69 | 83 | "\n", |
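| | + "In driver code, this whole journey collapses to a handful of actor calls. A hedged sketch (the spawn signature and endpoint names follow this notebook's pseudocode, not a verbatim Forge API):\n", |
| | + "\n", |
| | + "```python\n", |
| | + "# Sketch of the lifecycle: spawn -> setup -> train -> cleanup\n", |
| | + "procs = await proc_mesh(gpus=8)                     # 1-2. launch 8 GPU processes\n", |
| | + "trainer = await procs.spawn(\"trainer\", TrainerActor, cfg)\n", |
| | + "await trainer.setup.call()                          # 3. load model/data/checkpoint\n", |
| | + "await trainer.train.call()                          # 4. run the training loop\n", |
| | + "await trainer.cleanup.call()                        # 5. save and release resources\n", |
| | + "```\n", |
| | + "\n", |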
70 | 84 | "### Why This Architecture?\n", |
71 | 85 | "\n", |
72 | | - "✅ **Automatic Distribution** - Forge handles multi-GPU/multi-node complexity \n", |
| 86 | + "✅ **Automatic Distribution** - Monarch handles multi-GPU/multi-node complexity \n", |
73 | 87 | "✅ **Fault Tolerance** - Checkpointing enables recovery from failures \n", |
74 | 88 | "✅ **Flexibility** - Easy to switch between 1 GPU, 8 GPUs, or multiple nodes \n", |
75 | | - "✅ **Production-Ready** - Used at Meta for large-scale training\n", |
| 89 | + "✅ **Production-Ready** - Used at Meta for large-scale training \n", |
| 90 | + "✅ **Actor Pattern** - Clean separation of concerns with lifecycle methods\n", |
76 | 91 | "\n", |
| 92 | + "#### For more information regarding Forge visit: https://github.com/meta-pytorch/torchforge/tree/main/docs\n", |
77 | 93 | "---\n", |
78 | 94 | "\n", |
79 | 95 | "Let's configure your training!" |
|
502 | 518 | "cell_type": "markdown", |
503 | 519 | "metadata": {}, |
504 | 520 | "source": [ |
505 | | - "### Phase 2: Setup 🔧\n", |
| 521 | + "### Phase 2: Setup 🔧 [Monarch Endpoint]\n", |
506 | 522 | "\n", |
507 | 523 | "**What's happening:**\n", |
508 | | - "- **Model Loading**: Each process loads its shard of the model\n", |
509 | | - " - With FSDP, GPU 0 might get layers 0-10\n", |
| 524 | + "\n", |
| 525 | + "Monarch calls the `@endpoint` decorated `setup()` method on all 8 actor instances:\n", |
| 526 | + "\n", |
| 527 | + "```python\n", |
| 528 | + "class TrainerActor:\n", |
| 529 | + " @endpoint\n", |
| 530 | + " async def setup(self):\n", |
| 531 | + " # This runs on all 8 GPUs simultaneously\n", |
| 532 | + " ...\n", |
| 533 | + "```\n", |
| 534 | + "\n", |
| 535 | + "Each actor instance:\n", |
| 536 | + "- **Loads its shard of the model**: With FSDP, each GPU only loads ~1/8th\n", |
| 537 | + " - GPU 0 might get layers 0-10\n", |
510 | 538 | " - GPU 1 gets layers 11-20, etc.\n", |
511 | | - " - Each GPU only holds ~1/8th of the full model\n", |
512 | | - "- **Dataset Loading**: Training and validation dataloaders created\n", |
513 | | - " - Same dataset, but different random seeds per GPU\n", |
514 | | - " - Ensures each GPU sees different data\n", |
515 | | - "- **Checkpoint Loading**: If resuming, restore training state\n", |
516 | | - " - Model weights, optimizer state, current step number\n", |
| 539 | + "- **Creates dataloaders**: Same dataset, different random seeds per GPU\n", |
| 540 | + "- **Restores checkpoint**: If resuming, loads saved state\n", |
517 | 541 | "\n", |
518 | 542 | "**What `setup()` does internally:**\n", |
519 | 543 | "```python\n", |
520 | | - "def setup(self):\n", |
521 | | - " # 1. Initialize model with FSDP\n", |
| 544 | + "@endpoint\n", |
| 545 | + "async def setup(self):\n", |
| 546 | + " # 1. Initialize model with FSDP sharding\n", |
522 | 547 | " self.model = load_model_with_fsdp(cfg.model)\n", |
523 | 548 | " \n", |
524 | 549 | " # 2. Create training dataloader\n", |
|
537 | 562 | " self.checkpointer.load(step=self.current_step)\n", |
538 | 563 | "```\n", |
539 | 564 | "\n", |
540 | | - "After setup, all 8 GPUs are synchronized and ready to train!" |
| 565 | + "**Monarch magic:**\n", |
| 566 | + "- The `@endpoint` decorator makes this method callable remotely\n", |
| 567 | + "- Monarch ensures all 8 actors complete setup before proceeding\n", |
| 568 | + "- Distributed state (model shards) automatically synchronized\n", |
| 569 | + "\n", |
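| | + "From the driver's perspective, that fan-out is a single awaited call (hedged sketch; `trainer` is the actor mesh spawned earlier):\n", |
| | + "\n", |
| | + "```python\n", |
| | + "# Invokes setup() on all 8 actors and blocks until every one has finished\n", |
| | + "await trainer.setup.call()\n", |
| | + "```\n", |
| | + "\n", |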
| 570 | + "After setup, all 8 GPU actors are synchronized and ready to train!" |
541 | 571 | ] |
542 | 572 | }, |
543 | 573 | { |
544 | 574 | "cell_type": "code", |
545 | 575 | "execution_count": null, |
546 | | - "metadata": { |
547 | | - "output": { |
548 | | - "id": 693658349895675, |
549 | | - "loadingStatus": "loaded" |
550 | | - } |
551 | | - }, |
| 576 | + "metadata": {}, |
552 | 577 | "outputs": [], |
553 | 578 | "source": [ |
554 | 579 | "# Setup (load data, checkpoints, etc.)\n", |
|
560 | 585 | "cell_type": "markdown", |
561 | 586 | "metadata": {}, |
562 | 587 | "source": [ |
563 | | - "### Phase 3: Training Loop 🔄\n", |
| 588 | + "### Phase 3: Training Loop 🔄 [Monarch Endpoint]\n", |
564 | 589 | "\n", |
565 | 590 | "**What's happening:**\n", |
566 | 591 | "\n", |
567 | | - "The training loop runs for `cfg.training.steps` iterations. Each step:\n", |
| 592 | + "Monarch calls the `@endpoint` decorated `train()` method, which runs the training loop for `cfg.training.steps` iterations. Each step:\n", |
568 | 593 | "\n", |
569 | 594 | "```python\n", |
570 | | - "for step in range(current_step, max_steps):\n", |
571 | | - " # 1. Get next batch from dataloader\n", |
572 | | - " batch = next(train_dataloader)\n", |
573 | | - " # Shape: [batch_size, seq_len] per GPU\n", |
574 | | - " \n", |
575 | | - " # 2. Forward pass - compute predictions and loss\n", |
576 | | - " outputs = model(batch['input_ids'])\n", |
577 | | - " loss = compute_loss(outputs, batch['labels'])\n", |
578 | | - " \n", |
579 | | - " # 3. Backward pass - compute gradients\n", |
580 | | - " loss.backward()\n", |
581 | | - " # FSDP automatically synchronizes gradients across all GPUs!\n", |
582 | | - " \n", |
583 | | - " # 4. Optimizer step - update model weights\n", |
584 | | - " optimizer.step()\n", |
585 | | - " optimizer.zero_grad()\n", |
586 | | - " \n", |
587 | | - " # 5. Periodic validation (if enabled)\n", |
588 | | - " if validation_enabled and step % eval_interval == 0:\n", |
589 | | - " val_metrics = evaluate()\n", |
590 | | - " log(f\"Step {step}: Val Loss = {val_metrics['val_loss']}\")\n", |
591 | | - " \n", |
592 | | - " # 6. Periodic checkpointing\n", |
593 | | - " if step % checkpoint_interval == 0:\n", |
594 | | - " save_checkpoint(step)\n", |
| 595 | + "@endpoint\n", |
| 596 | + "async def train(self):\n", |
| 597 | + " for step in range(current_step, max_steps):\n", |
| 598 | + " # 1. Get next batch from dataloader\n", |
| 599 | + " batch = next(train_dataloader)\n", |
| 600 | + " # Shape: [batch_size, seq_len] per GPU\n", |
| 601 | + "\n", |
| 602 | + " # 2. Forward pass - compute predictions and loss\n", |
| 603 | + " outputs = model(batch['input_ids'])\n", |
| 604 | + " loss = compute_loss(outputs, batch['labels'])\n", |
| 605 | + "\n", |
| 606 | + " # 3. Backward pass - compute gradients\n", |
| 607 | + " loss.backward()\n", |
| 608 | + " # FSDP automatically synchronizes gradients across all GPUs!\n", |
| 609 | + "\n", |
| 610 | + " # 4. Optimizer step - update model weights\n", |
| 611 | + " optimizer.step()\n", |
| 612 | + " optimizer.zero_grad()\n", |
| 613 | + "\n", |
| 614 | + " # 5. Periodic validation (if enabled)\n", |
| 615 | + " if validation_enabled and step % eval_interval == 0:\n", |
| 616 | + " val_metrics = evaluate()\n", |
| 617 | + " log(f\"Step {step}: Val Loss = {val_metrics['val_loss']}\")\n", |
| 618 | + "\n", |
| 619 | + " # 6. Periodic checkpointing\n", |
| 620 | + " if step % checkpoint_interval == 0:\n", |
| 621 | + " save_checkpoint(step)\n", |
595 | 622 | "```\n", |
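| | + "\n", |
| | + "The gradient synchronization in step 3 comes from the FSDP wrapper, not from the loop itself. A minimal stand-alone sketch (`my_model`, `batch`, and `compute_loss` are placeholders; assumes `torchrun` launched one process per GPU):\n", |
| | + "\n", |
| | + "```python\n", |
| | + "import torch.distributed as dist\n", |
| | + "from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n", |
| | + "\n", |
| | + "dist.init_process_group(\"nccl\")           # torchrun supplies rank/world size\n", |
| | + "model = FSDP(my_model.cuda())             # parameters sharded across ranks\n", |
| | + "outputs = model(batch[\"input_ids\"])       # forward all-gathers shards on demand\n", |
| | + "loss = compute_loss(outputs, batch[\"labels\"])\n", |
| | + "loss.backward()                           # grads reduce-scattered across GPUs\n", |
| | + "```\n", |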
596 | 623 | "\n", |
597 | 624 | "**Key insights:**\n", |
|
604 | 631 | "- Training loss decreasing over time\n", |
605 | 632 | "- Periodic validation metrics (if enabled)\n", |
606 | 633 | "- Checkpoint saves at regular intervals\n", |
607 | | - "- Step timing information (seconds per step)" |
| 634 | + "- Step timing information (seconds per step)\n", |
| 635 | + "\n", |
| 636 | + "**Monarch magic:**\n", |
| 637 | + "- The `@endpoint` decorator makes this long-running training loop remotely callable\n", |
| 638 | + "- All 8 actor instances run training in sync\n", |
| 639 | + "- Monarch handles any RPC timeouts for long-running operations" |
608 | 640 | ] |
609 | 641 | }, |
610 | 642 | { |
611 | 643 | "cell_type": "code", |
612 | 644 | "execution_count": null, |
613 | | - "metadata": { |
614 | | - "output": { |
615 | | - "id": 4257826794454822, |
616 | | - "loadingStatus": "loaded" |
617 | | - } |
618 | | - }, |
| 645 | + "metadata": {}, |
619 | 646 | "outputs": [], |
620 | 647 | "source": [ |
621 | 648 | "# Run training\n", |
|