diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css
index 89854cc8b..5f2d0897b 100644
--- a/docs/source/_static/custom.css
+++ b/docs/source/_static/custom.css
@@ -1,3 +1,56 @@
+/* Center all Mermaid diagrams */
+.mermaid {
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+    text-align: center;
+}
+
+/* Center the pre element that contains mermaid diagrams */
+pre.mermaid {
+    display: flex;
+    justify-content: center;
+}
+
+/* Adjust Mermaid line colors based on theme */
+/* Light mode - darker lines for visibility on white background */
+html[data-theme="light"] .mermaid .edgePath .path,
+html[data-theme="light"] .mermaid .flowchart-link {
+    stroke: #555 !important;
+    stroke-width: 2px !important;
+}
+
+/* Light mode - darker arrow tips */
+html[data-theme="light"] .mermaid .arrowheadPath,
+html[data-theme="light"] .mermaid marker path {
+    fill: #555 !important;
+    stroke: #555 !important;
+}
+
+/* Dark mode - lighter lines for visibility on dark background */
+html[data-theme="dark"] .mermaid .edgePath .path,
+html[data-theme="dark"] .mermaid .flowchart-link {
+    stroke: #aaa !important;
+    stroke-width: 2px !important;
+}
+
+/* Dark mode - lighter arrow tips */
+html[data-theme="dark"] .mermaid .arrowheadPath,
+html[data-theme="dark"] .mermaid marker path {
+    fill: #aaa !important;
+    stroke: #aaa !important;
+}
+
+/* Adjust edge labels background based on theme */
+html[data-theme="light"] .mermaid .edgeLabel {
+    background-color: #fff !important;
+}
+
+html[data-theme="dark"] .mermaid .edgeLabel {
+    background-color: #1a1a1a !important;
+    color: #fff !important;
+}
+
/* Custom CSS for collapsible parameter lists */
/* Hide parameters in signatures */
diff --git a/docs/source/_static/custom.js b/docs/source/_static/custom.js
index 415592d30..fa794ae89 100644
--- a/docs/source/_static/custom.js
+++ b/docs/source/_static/custom.js
@@ -1,3 +1,117 @@
+// Lightbox functionality for Mermaid diagrams.
+// Clicking a diagram opens an enlarged copy in a full-page overlay; the overlay
+// closes via the close button, a click on the backdrop, or the Escape key.
+document.addEventListener('DOMContentLoaded', function() {
+    // Full-viewport backdrop, hidden until a diagram is clicked
+    const lightbox = document.createElement('div');
+    lightbox.id = 'mermaid-lightbox';
+    lightbox.style.cssText = `
+        display: none;
+        position: fixed;
+        z-index: 9999;
+        left: 0;
+        top: 0;
+        width: 100%;
+        height: 100%;
+        background-color: rgba(0,0,0,0.9);
+        cursor: zoom-out;
+    `;
+
+    // Centered, scrollable container that holds the enlarged copy
+    const lightboxContent = document.createElement('div');
+    lightboxContent.style.cssText = `
+        position: absolute;
+        top: 50%;
+        left: 50%;
+        transform: translate(-50%, -50%);
+        max-width: 95%;
+        max-height: 95%;
+        overflow: auto;
+    `;
+
+    // Single place that hides the overlay. The enlarged copy is discarded too,
+    // so a duplicate of the diagram (and of any element ids inside its SVG)
+    // does not linger in the document while the overlay is hidden.
+    function closeLightbox() {
+        lightbox.style.display = 'none';
+        lightboxContent.innerHTML = '';
+    }
+
+    const closeBtn = document.createElement('span');
+    // U+00D7 (multiplication sign) serves as the close glyph
+    closeBtn.textContent = '\u00D7';
+    // Give the glyph-only control a role and an accessible name
+    closeBtn.setAttribute('role', 'button');
+    closeBtn.setAttribute('aria-label', 'Close');
+    closeBtn.style.cssText = `
+        position: absolute;
+        top: 15px;
+        right: 35px;
+        color: #f1f1f1;
+        font-size: 40px;
+        font-weight: bold;
+        cursor: pointer;
+        z-index: 10000;
+    `;
+    closeBtn.onclick = closeLightbox;
+
+    lightbox.appendChild(closeBtn);
+    lightbox.appendChild(lightboxContent);
+    document.body.appendChild(lightbox);
+
+    // Click outside to close (only clicks that land on the backdrop itself)
+    lightbox.onclick = function(e) {
+        if (e.target === lightbox) {
+            closeLightbox();
+        }
+    };
+
+    // ESC key to close, only while the overlay is showing
+    document.addEventListener('keydown', function(e) {
+        if (e.key === 'Escape' && lightbox.style.display === 'block') {
+            closeLightbox();
+        }
+    });
+
+    // Make all Mermaid diagrams clickable
+    document.querySelectorAll('.mermaid').forEach(function(div) {
+        div.style.cursor = 'zoom-in';
+        div.title = 'Click to enlarge';
+
+        div.addEventListener('click', function() {
+            const clone = div.cloneNode(true);
+            // The copy is already enlarged; drop the inherited "Click to enlarge" tooltip
+            clone.removeAttribute('title');
+
+            // Style the cloned diagram to fill the lightbox
+            clone.style.cssText = `
+                cursor: default;
+                width: 90vw;
+                max-width: 1400px;
+                height: auto;
+                margin: auto;
+            `;
+
+            // Let the SVG scale with its container instead of its rendered size
+            const svg = clone.querySelector('svg');
+            if (svg) {
+                svg.style.cssText = `
+                    width: 100% !important;
+                    height: auto !important;
+                    max-width: none !important;
+                    max-height: 90vh !important;
+                `;
+                svg.removeAttribute('width');
+                svg.removeAttribute('height');
+            }
+
+            lightboxContent.innerHTML = '';
+            lightboxContent.appendChild(clone);
+            lightbox.style.display = 'block';
+        });
+    });
+});
+
// Custom JavaScript to make long parameter lists in class signatures collapsible
document.addEventListener('DOMContentLoaded', function() {
console.log('Collapsible parameters script loaded');
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2ee6771ea..c0f8280d9 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -178,6 +178,30 @@ def get_version_path():
# Configure MyST parser to treat mermaid code blocks as mermaid directives
myst_fence_as_directive = ["mermaid"]
+# Disable D3 zoom; the click-to-enlarge lightbox in _static/custom.js is used instead
+mermaid_d3_zoom = False
+
+# Global Mermaid theme configuration - applies to all diagrams. NOTE(review): _static/custom.css sets the edge stroke-width to 2px !important for light/dark themes, which presumably overrides the 4px in themeCSS below - confirm.
+mermaid_init_js = """
+import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@11.2.0/dist/mermaid.esm.min.mjs';
+mermaid.initialize({
+ startOnLoad: false,
+ theme: 'base',
+ themeVariables: {
+ primaryColor: '#4CAF50',
+ primaryTextColor: '#000',
+ primaryBorderColor: '#fff',
+ lineColor: '#555',
+ secondaryColor: '#FF9800',
+ tertiaryColor: '#ffffde'
+ },
+ flowchart: {
+ curve: 'basis'
+ },
+ themeCSS: '.edgePath .path { stroke-width: 4px; stroke: #555; }'
+});
+"""
+
autodoc_default_options = {
"members": True,
"undoc-members": True,
diff --git a/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md b/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md
index ae74df101..1204366e6 100644
--- a/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md
+++ b/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md
@@ -11,8 +11,10 @@ graph TD
subgraph Example["Math Tutoring RL Example"]
Dataset["Dataset: math problems"]
Policy["Policy: student AI"]
- Reward["Reward Model: scores answers"]
- Reference["Reference Model: baseline"]
+ Reward["Reward Model:
+ scores answers"]
+ Reference["Reference Model:
+ baseline"]
ReplayBuffer["Replay Buffer: stores experiences"]
Trainer["Trainer: improves student"]
end
@@ -25,9 +27,11 @@ graph TD
ReplayBuffer --> Trainer
Trainer --> Policy
- style Policy fill:#4CAF50
- style Reward fill:#FF9800
- style Trainer fill:#E91E63
+ style Policy fill:#4CAF50,stroke:#fff,stroke-width:2px
+ style Reward fill:#FF9800,stroke:#fff,stroke-width:2px
+ style Trainer fill:#E91E63,stroke:#fff,stroke-width:2px
+
+ linkStyle default stroke:#888,stroke-width:2px
```
### RL Components Defined (TorchForge Names)
@@ -76,7 +80,7 @@ Here's the key insight: **Each RL component becomes a TorchForge service**. The
```mermaid
graph LR
subgraph Concepts["RL Concepts"]
- direction TB
+
C1["Dataset"]
C2["Policy"]
C3["Reward Model"]
@@ -86,7 +90,7 @@ graph LR
end
subgraph Services["TorchForge Services (Real Classes)"]
- direction TB
+
S1["DatasetActor"]
S2["Generator"]
S3["RewardActor"]
@@ -173,11 +177,20 @@ Our simple RL loop above has complex requirements:
```mermaid
graph LR
- A["Policy: Student AI
'What is 2+2?' → 'The answer is 4'"]
- B["Reward: Teacher
Scores answer: 0.95"]
- C["Reference: Original Student
Provides baseline comparison"]
- D["Replay Buffer: Notebook
Stores: question + answer + score"]
- E["Trainer: Tutor
Improves student using experiences"]
+ A["Policy: Student AI
+ 'What is 2+2?' →
+ 'The answer is 4'"]
+ B["Reward: Teacher
+ Scores answer: 0.95"]
+ C["Reference: Original Student
+ Provides baseline comparison"]
+ D["Replay Buffer: Notebook
+ Stores: question
+ + answer
+ + score"]
+ E["Trainer: Tutor
+ Improves student
+ using experiences"]
A --> B
A --> C
diff --git a/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md b/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md
index 335c5fc5a..1dabe0d5f 100644
--- a/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md
+++ b/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md
@@ -12,22 +12,35 @@ When you call `await policy_service.generate(question)`, here's what actually ha
```mermaid
graph TD
- Call["Your Code:
await policy_service.generate"]
+ Call["Your Code:
+ await policy_service
+ .generate.route"]
subgraph ServiceLayer["Service Layer"]
- Proxy["Service Proxy: Load balancing, Health checking"]
- LB["Load Balancer: Replica selection, Circuit breaker"]
+ Proxy["Service Proxy:
+ Load balancing
+ Health checking"]
+ LB["Load Balancer:
+ Replica selection
+ Circuit breaker"]
end
subgraph Replicas["Replica Management"]
- R1["Replica 1: GPU 0, Healthy"]
- R2["Replica 2: GPU 1, Overloaded"]
- R3["Replica 3: GPU 2, Failed"]
- R4["Replica 4: GPU 3, Healthy"]
+ R1["Replica 1:
+ GPU 0, Healthy"]
+ R2["Replica 2:
+ GPU 1, Overloaded"]
+ R3["Replica 3:
+ GPU 2, Failed"]
+ R4["Replica 4:
+ GPU 3, Healthy"]
end
subgraph Compute["Actual Computation"]
- Actor["Policy Actor: vLLM engine, Model weights, KV cache"]
+ Actor["Policy Actor:
+ vLLM engine,
+ Model weights,
+ KV cache"]
end
Call --> Proxy
@@ -126,13 +139,17 @@ responses = await policy.generate.route(prompt=prompt)
```mermaid
graph LR
subgraph Request["Your Request"]
- Code["await service.method.ADVERB()"]
+ Code["await service
+ .method.ADVERB()"]
end
subgraph Patterns["Communication Patterns"]
- Route[".route()
→ One healthy replica"]
- CallOne[".call_one()
→ Single actor"]
- Fanout[".fanout()
→ ALL replicas"]
+ Route[".route()
+ → One healthy replica"]
+ CallOne[".call_one()
+ → Single actor"]
+ Fanout[".fanout()
+ → ALL replicas"]
end
subgraph Replicas["Replicas/Actors"]
diff --git a/docs/source/tutorial_sources/zero-to-forge/3_Monarch_101.md b/docs/source/tutorial_sources/zero-to-forge/3_Monarch_101.md
index 8a53566c0..59163d6c7 100644
--- a/docs/source/tutorial_sources/zero-to-forge/3_Monarch_101.md
+++ b/docs/source/tutorial_sources/zero-to-forge/3_Monarch_101.md
@@ -8,25 +8,51 @@ Now let's peel back the layers. TorchForge services are built on top of **Monarc
```mermaid
graph TD
- subgraph YourCode["1. Your RL Code"]
- Call["await policy_service.generate.route('What is 2+2?')"]
+ subgraph YourCode["(1) Your RL Code"]
+ Call["await policy_service
+ .generate.route
+ ('What is 2+2?')"]
end
- subgraph ForgeServices["2. TorchForge Service Layer"]
- ServiceInterface["ServiceInterface: Routes requests, Load balancing, Health checks"]
- ServiceActor["ServiceActor: Manages replicas, Monitors health, Coordinates failures"]
+ subgraph ForgeServices["(2) TorchForge Service Layer"]
+ ServiceInterface["ServiceInterface:
+ Routes requests
+ Load balancing
+ Health checks"]
+ ServiceActor["ServiceActor:
+ Manages replicas
+ Monitors health
+ Coordinates failures"]
end
- subgraph MonarchLayer["3. Monarch Actor Layer"]
+ subgraph MonarchLayer["(3) Monarch Actor Layer"]
- ActorMesh["ActorMesh Policy Actor: 4 instances, Different GPUs, Message passing"]
- ProcMesh["ProcMesh: 4 processes, GPU topology 0,1,2,3, Network interconnect"]
+ ActorMesh["ActorMesh Policy Actor:
+ 4 instances
+ Different GPUs
+ Message passing"]
+ ProcMesh["ProcMesh:
+ 4 processes
+ GPU topology 0,1,2,3
+ Network interconnect"]
end
- subgraph Hardware["4. Physical Hardware"]
+ subgraph Hardware["(4) Physical Hardware"]
- GPU0["GPU 0: Policy Actor #1, vLLM Engine, Model Weights"]
- GPU1["GPU 1: Policy Actor #2, vLLM Engine, Model Weights"]
- GPU2["GPU 2: Policy Actor #3, vLLM Engine, Model Weights"]
- GPU3["GPU 3: Policy Actor #4, vLLM Engine, Model Weights"]
+ GPU0["GPU 0:
+ Policy Actor #1
+ vLLM Engine
+ Model Weights"]
+ GPU1["GPU 1:
+ Policy Actor #2
+ vLLM Engine
+ Model Weights"]
+ GPU2["GPU 2:
+ Policy Actor #3
+ vLLM Engine
+ Model Weights"]
+ GPU3["GPU 3:
+ Policy Actor #4
+ vLLM Engine
+ Model Weights"]
end
Call --> ServiceInterface
@@ -177,29 +203,49 @@ Now the key insight: **TorchForge services are ServiceActors that manage ActorMe
```mermaid
graph TD
subgraph ServiceCreation["Service Creation Process"]
- Call["await Policy.options(num_replicas=4, procs=1).as_service(model='Qwen')"]
-
- ServiceActor["ServiceActor: Manages 4 replicas, Health checks, Routes calls"]
+ Call["await Policy
+ .options(
+ num_replicas=4,
+ procs=1)
+ .as_service(
+ model='Qwen')"]
+
+ ServiceActor["ServiceActor:
+ Manages 4 replicas
+ Health checks
+ Routes calls"]
subgraph Replicas["4 Independent Replicas"]
subgraph R0["Replica 0"]
- PM0["ProcMesh: 1 process, GPU 0"]
- AM0["ActorMesh
1 Policy Actor"]
+ PM0["ProcMesh:
+ 1 process
+ GPU 0"]
+ AM0["ActorMesh
+ 1 Policy Actor"]
end
subgraph R1["Replica 1"]
- PM1["ProcMesh: 1 process, GPU 1"]
- AM1["ActorMesh
1 Policy Actor"]
+ PM1["ProcMesh:
+ 1 process
+ GPU 1"]
+ AM1["ActorMesh
+ 1 Policy Actor"]
end
subgraph R2["Replica 2"]
- PM2["ProcMesh: 1 process, GPU 2"]
- AM2["ActorMesh
1 Policy Actor"]
+ PM2["ProcMesh:
+ 1 process
+ GPU 2"]
+ AM2["ActorMesh
+ 1 Policy Actor"]
end
subgraph R3["Replica 3"]
- PM3["ProcMesh: 1 process, GPU 3"]
- AM3["ActorMesh
1 Policy Actor"]
+ PM3["ProcMesh:
+ 1 process
+ GPU 3"]
+ AM3["ActorMesh
+ 1 Policy Actor"]
end
end
@@ -224,19 +270,35 @@ graph TD
### Service Call to Actor Execution
```mermaid
+:align: center
graph TD
subgraph CallFlow["Complete Call Flow"]
- UserCall["await policy_service.generate.route('What is 2+2?')"]
+ UserCall["await policy_service
+ .generate.route
+ ('What is 2+2?')"]
- ServiceInterface["ServiceInterface: Receives .route() call, Routes to ServiceActor"]
+ ServiceInterface["ServiceInterface:
+ Receives .route() call
+ Routes to ServiceActor"]
- ServiceActor["ServiceActor: Selects healthy replica, Load balancing, Failure handling"]
+ ServiceActor["ServiceActor:
+ Selects healthy replica
+ Load balancing
+ Failure handling"]
- SelectedReplica["Selected Replica #2: ProcMesh 1 process, ActorMesh 1 Policy Actor"]
+ SelectedReplica["Selected Replica #2:
+ ProcMesh 1 process
+ ActorMesh 1 Policy Actor"]
- PolicyActor["Policy Actor Instance: Loads model, Runs vLLM inference"]
+ PolicyActor["Policy Actor Instance:
+ Loads model
+ Runs vLLM inference"]
- GPU["GPU 2: vLLM engine, Model weights, KV cache, CUDA kernels"]
+ GPU["GPU 2:
+ vLLM engine
+ Model weights
+ KV cache
+ CUDA kernels"]
UserCall --> ServiceInterface
ServiceInterface --> ServiceActor
@@ -265,28 +327,28 @@ In real RL systems, you have multiple services that can share or use separate Pr
graph TD
subgraph Cluster["RL Training Cluster"]
subgraph Services["TorchForge Services"]
- PS["Policy Service
4 GPU replicas"]
- TS["Trainer Service
2 GPU replicas"]
- RS["Reward Service
4 CPU replicas"]
- BS["Buffer Service
1 CPU replica"]
+ PS["Policy Service - 4 GPU replicas"]
+ TS["Trainer Service - 2 GPU replicas"]
+ RS["Reward Service - 4 CPU replicas"]
+ BS["Buffer Service - 1 CPU replica"]
end
subgraph MonarchInfra["Monarch Infrastructure"]
subgraph GPUMesh["GPU ProcMesh (6 processes)"]
- G0["Process 0
GPU 0"]
- G1["Process 1
GPU 1"]
- G2["Process 2
GPU 2"]
- G3["Process 3
GPU 3"]
- G4["Process 4
GPU 4"]
- G5["Process 5
GPU 5"]
+ G0["Process 0 - GPU 0"]
+ G1["Process 1 - GPU 1"]
+ G2["Process 2 - GPU 2"]
+ G3["Process 3 - GPU 3"]
+ G4["Process 4 - GPU 4"]
+ G5["Process 5 - GPU 5"]
end
subgraph CPUMesh["CPU ProcMesh (5 processes)"]
- C0["Process 0
CPU"]
- C1["Process 1
CPU"]
- C2["Process 2
CPU"]
- C3["Process 3
CPU"]
- C4["Process 4
CPU"]
+ C0["Process 0 - CPU"]
+ C1["Process 1 - CPU"]
+ C2["Process 2 - CPU"]
+ C3["Process 3 - CPU"]
+ C4["Process 4 - CPU"]
end
end
diff --git a/docs/source/zero-to-forge-intro.md b/docs/source/zero-to-forge-intro.md
index c9f2e98d2..7e815c83b 100644
--- a/docs/source/zero-to-forge-intro.md
+++ b/docs/source/zero-to-forge-intro.md
@@ -12,9 +12,9 @@ PyTorch tutorial, shoutout to our PyTorch friends that remember!
This section currently is structured in 3 detailed parts:
-1. [RL Fundamentals and Understanding TorchForge Terminology](tutorials/zero-to-forge/1_RL_and_Forge_Fundamentals): This gives a quick refresher of Reinforcement Learning and teaches you TorchForge Fundamentals
-2. [TorchForge Internals](tutorials/zero-to-forge/2_Forge_Internals): Goes a layer deeper and explains the internals of TorchForge
-3. [Monarch 101](tutorials/zero-to-forge/3_Monarch_101): It's a 101 to Monarch and how TorchForge Talks to Monarch
+1. [Part 1: RL Fundamentals - Using TorchForge Terminology](tutorials/zero-to-forge/1_RL_and_Forge_Fundamentals): This gives a quick refresher of Reinforcement Learning and teaches you TorchForge Fundamentals
+2. [Part 2: Peeling Back the Abstraction - What Are Services?](tutorials/zero-to-forge/2_Forge_Internals): Goes a layer deeper and explains the internals of TorchForge
+3. [Part 3: The TorchForge-Monarch Connection](tutorials/zero-to-forge/3_Monarch_101): It's a 101 on Monarch and how TorchForge talks to Monarch
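-Each part builds upon the next and the entire section can be consumed in roughly an hour - Grab a Chai and Enjoy!
+Each part builds upon the previous one and the entire section can be consumed in roughly an hour - Grab a Chai and Enjoy!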