diff --git a/demo.ipynb b/demo.ipynb new file mode 100644 index 000000000..459fa0654 --- /dev/null +++ b/demo.ipynb @@ -0,0 +1,677 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ed6165f4-0038-40d3-bf65-419fcf61af24", + "metadata": {}, + "source": [ + "# Intro to Forge\n", + "\n", + "Forge is a PyTorch-native framework designed for rapid experimentation and large-scale training of Reinforcement Learning (RL) algorithms with Large Language Models (LLMs). It's designed to:\n", + "- Express RL algorithms as naturally as pseudocode, while scaling seamlessly across clusters\n", + "- Support varying degrees of asynchrony - from fully synchronous/on-policy, to fully asynchronous/off-policy training\n", + "- Separate infrastructural concerns from algorithmic implementation\n", + "- Bias towards composable, reusable components that can be mixed and matched for different RL approaches\n", + "\n", + "Forge is built on top of proven components:\n", + "- **[Monarch](https://github.com/meta-pytorch/monarch)** - PyTorch-native single-controller framework\n", + "- **[torchtitan](https://github.com/pytorch/torchtitan)** - PyTorch-native large-scale LLM training platform\n", + "- **[vLLM](https://github.com/vllm-project/vllm)** - A high-throughput, memory-efficient inference and serving engine for LLMs\n", + "\n", + "Our mission is to accelerate innovation in reinforcement learning by empowering researchers and developers to explore new RL algorithms and infrastructure techniques. Whether you're designing novel training methods or optimizing distributed systems, Forge provides a foundation to build upon." + ] + }, + { + "cell_type": "markdown", + "id": "24de9912-ed10-4729-9616-2f85bbf64e43", + "metadata": {}, + "source": [ + "## Brief Intro to Monarch\n", + "Before diving into Forge, we need to first establish the foundation. Forge is built on top of Monarch, PyTorch's native single-controller framework for distributed execution.\n", + "\n", + "Forge builds many of its abstractions on top of Monarch, so it's worth introducing a few of its key concepts first. The following sections borrow from Monarch's Getting Started Guide (not public yet).\n", + "\n", + "### Defining an Actor\n", + "At its core, Monarch uses [actors](https://en.wikipedia.org/wiki/Actor_model) as a way to create multi-machine programs. Actors are Python objects that expose a number of endpoint functions. These functions can be called by other actors in the system and their responses gathered asynchronously."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "334ff5d6-bf9f-4b53-a04e-b488083a8101", + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "from monarch.actor import Actor, endpoint, this_proc\n", + "\n", + "class Counter(Actor):\n", + " def __init__(self, initial_value: int):\n", + " self.value = initial_value\n", + "\n", + " @endpoint\n", + " def increment(self) -> None:\n", + " self.value += 1\n", + "\n", + " @endpoint\n", + " def get_value(self) -> int:\n", + " return self.value\n" + ] + }, + { + "cell_type": "markdown", + "id": "8b2815f7-ad6f-4928-b262-033c9b5cb847", + "metadata": {}, + "source": [ + "The decorator `@endpoint` specifies functions of the Actor that can be called remotely from other actors.\n", + "\n", + "### Spawning An Actor In The Local Process\n", + "\n", + "We spawn actors in the current running process like so:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9e22453-d877-4334-8d80-b3bc1de85455", + "metadata": {}, + "outputs": [], + "source": [ + "counter: Counter = this_proc().spawn(\"counter\", Counter, initial_value=0)" + ] + }, + { + "cell_type": "markdown", + "id": "b8a9aa84-2ef4-4961-9173-4fee73c065c5", + "metadata": {}, + "source": [ + "`this_proc()` is a handle to a process, giving us direct control over where an actor runs. Monarch is very literal about where things are run, so that code can be written in the most efficient way. \n", + "\n", + "### Sending A Simple Message\n", + "Once an actor is spawned, we can send messages to the actor:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d20ada2-9084-4439-bd9b-e95881cf7009", + "metadata": {}, + "outputs": [], + "source": [ + "from monarch.actor import Future\n", + "\n", + "fut: Future[int] = counter.get_value.call_one()\n", + "\n", + "value = await fut\n", + "\n", + "print(f\"Counter value: {value}\")" + ] + }, + { + "cell_type": "markdown", + "id": "6bd15594-380a-44a1-977c-c536b6ae3a9c", + "metadata": {}, + "source": [ + "Here we invoke the `get_value` endpoint, which returns 0, the current value of the Counter. `call_one` here is referred to as an \"adverb\" because it modifies how results of the endpoint are handled. `call_one` just invokes a single actor and gets its value.\n", + "\n", + "Notice that the return value is a `Future[int]` - the message is sent asynchronously, letting the sender do other things before it needs the reply. We can `await` the result.\n", + "\n", + "### Multiple Actors at Once\n", + "Monarch scales to thousands of machines because of its ability to broadcast a single message to many actors at once, rather than send many point-to-point messages.\n", + "\n", + "Monarch expresses broadcasted communication by organizing actors into a `Mesh` - a multi-dimensional container with named dimensions. An example cluster may have dimensions `{\"hosts\": 32, \"gpus\": 8}`. 
To create a mesh of actors, we'll create a mesh of processes and spawn an actor on them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4349a0c4-161c-4d68-9d97-202b650e344c", + "metadata": {}, + "outputs": [], + "source": [ + "from monarch.actor import ProcMesh, this_host\n", + "\n", + "procs: ProcMesh = this_host().spawn_procs(per_host={\"gpus\": 8})\n", + "counters: Counter = procs.spawn(\"counters\", Counter, 0)" + ] + }, + { + "cell_type": "markdown", + "id": "202a8126-2e45-4074-a2ba-87e12d2f06dc", + "metadata": {}, + "source": [ + "### Broadcasting Messages\n", + "Now messages can be sent to all actors in the mesh:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1395ba9-c47a-4902-bea2-320bd1144fd2", + "metadata": {}, + "outputs": [], + "source": [ + "await counters.increment.call()" + ] + }, + { + "cell_type": "markdown", + "id": "35e2f5da-1840-49f5-b187-47be6d4b185d", + "metadata": {}, + "source": [ + "Note that here we use the `call()` adverb. You will see other adverbs in Monarch code as well:\n", + "- `call_one()` - invoke a single actor and get its value (what we saw before)\n", + "- `choose()` - randomly invoke a single actor and get its value from within a mesh of actors\n", + "- `call()` - invoke all actors in an actor mesh, and return their values as a `ValueMesh` \n", + "- `broadcast()` - fire-and-forget to all actors in an actor mesh\n", + "- `stream()` - invoke all actors and return their values as an iterator\n", + "\n", + "There's much more to cover with Monarch, but these foundations provide the building blocks needed to understand how Forge creates its RL-specific services on top of this distributed actor system." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34fbc2ba-620c-4170-8ad6-97e78ac3f0b0", + "metadata": {}, + "outputs": [], + "source": [ + "await procs.stop()" + ] + }, + { + "cell_type": "markdown", + "id": "924b1514-b10c-4a4e-bc11-222f3f2a2933", + "metadata": {}, + "source": [ + "## Forge Services\n", + "Forge introduces *Services* - a higher-level abstraction built on top of Monarch actors. Services handle all the operational complexity of managing distributed ActorMeshes: spawning actors across nodes, fault tolerance, load balancing, and intelligent routing.\n", + "\n", + "### Creating a Forge Service\n", + "Creating a Forge service requires only minimal changes to the actors we've created above. 
You replace the `Actor` base class with `ForgeActor`, and change how you spawn the actor:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "248530c1-3cbc-44b1-aa80-aab335280870", + "metadata": {}, + "outputs": [], + "source": [ + "from forge.controller import ForgeActor\n", + "from forge.controller.service import ServiceConfig, spawn_service, shutdown_service\n", + "from monarch.actor import endpoint\n", + "\n", + "\n", + "class ForgeCounter(ForgeActor):\n", + " def __init__(self, initial_value: int):\n", + " self.value = initial_value\n", + "\n", + " @endpoint\n", + " def increment(self) -> int:\n", + " self.value += 1\n", + " return self.value\n", + "\n", + " @endpoint\n", + " def get_value(self) -> int:\n", + " return self.value\n", + "\n", + " @endpoint\n", + " async def reset(self):\n", + " self.value = 0\n", + "\n", + " @endpoint\n", + " def fail(self):\n", + " raise RuntimeError(\"I was asked to fail\")\n", + "\n", + "\n", + "counter_service = await spawn_service(\n", + " ServiceConfig(procs_per_replica=1, num_replicas=4),\n", + " ForgeCounter,\n", + " initial_value=0)" + ] + }, + { + "cell_type": "markdown", + "id": "8f905101-9a69-4532-8e88-711c83ed1570", + "metadata": {}, + "source": [ + "Here, we've created a simple \"Counter service\" with 4 replicas, each running on 1 process.\n", + "\n", + "### Service Adverbs: Operating at the Replica Level\n", + "Services introduce new adverbs that operate at the replica level rather than on individual actors. Since replicas can be spawned across multiple processes, each replica is essentially an ActorMesh in Monarch terms:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07a02047-f29a-47bf-a077-6bd5c2976cf1", + "metadata": {}, + "outputs": [], + "source": [ + "# choose() - routes to one replica (load balanced; a replica may contain multiple actors)\n", + "await counter_service.increment.choose()\n", + "\n", + "# call() - runs on ALL replicas\n", + "results = await counter_service.increment.call()\n", + "\n", + "print(results)" + ] + }, + { + "cell_type": "markdown", + "id": "0a08b5a5-0180-4e0d-80b4-cc0620c510b4", + "metadata": {}, + "source": [ + "Key distinction:\n", + "- Monarch's `choose()` picks a single actor from an `ActorMesh`\n", + "- Forge's `choose()` picks a single replica (which could be an entire `ActorMesh` of actors)\n", + "\n", + "This abstraction lets you think in terms of logical compute units (replicas) rather than individual processes.\n",
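+ "\n",
+ "As a rough, non-runnable illustration - `counters` is the Monarch mesh from earlier (already stopped) and `counter_service` is the Forge service above - the two calls look identical at the call site but target different units:\n",
+ "\n",
+ "```python\n",
+ "# Monarch: choose() picks one actor at random from the ActorMesh\n",
+ "value = await counters.get_value.choose()\n",
+ "\n",
+ "# Forge: choose() routes to one replica, which may itself be a whole ActorMesh of actors\n",
+ "value = await counter_service.get_value.choose()\n",
+ "```"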
+ ] + }, + { + "cell_type": "markdown", + "id": "4ac9c7dc-42b5-4e10-a0ae-34916fc40360", + "metadata": {}, + "source": [ + "### Load Balancing in Action\n", + "Services handle load balancing:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11181a5a-5692-4163-a96a-151d9b075454", + "metadata": {}, + "outputs": [], + "source": [ + "await counter_service.reset.call()\n", + "print(\"Increment on different replicas:\")\n", + "for i in range(8):\n", + " result = await counter_service.increment.choose()\n", + " print(f\"Call {i}: Counter value = {result}\")" + ] + }, + { + "cell_type": "markdown", + "id": "64c9334f-4db7-4cde-8ef2-d8e66a26d21b", + "metadata": {}, + "source": [ + "Each replica maintains its own state, and requests are distributed evenly.\n", + "\n", + "### Sticky Session for Stateful Operations\n", + "When you need to interact with the same replica consistently:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "453ec452-8a71-44ca-8afe-4393270356a1", + "metadata": {}, + "outputs": [], + "source": [ + "# Use sticky sessions to stay on one replica\n", + "async with counter_service.session():\n", + " await counter_service.reset.choose()\n", + " print(await counter_service.increment.choose())\n", + " print(await counter_service.increment.choose())\n", + " print(await counter_service.increment.choose())\n", + " \n", + " final_value = await counter_service.get_value.choose()\n", + " print(f\"Final value on this replica: {final_value}\")" + ] + }, + { + "cell_type": "markdown", + "id": "b44eded4-d987-4620-9cf4-bbe3205102b1", + "metadata": {}, + "source": [ + "Sticky sessions can be crucial for efficiency, e.g. whenever you need to maintain KV cache state across multiple turns.\n", + "\n", + "### Automatic Fault Tolerance\n", + "Services automatically handle failures:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48173946-c50f-4649-9064-ffaa4e116a7d", + "metadata": {}, + "outputs": [], + "source": [ + "# This will fail on one replica (the fail endpoint raises a RuntimeError)\n", + "try:\n", + " await counter_service.fail.choose()\n", + "except Exception:\n", + " print(\"Replica failed, but service continues...\")\n", + "\n", + "# Subsequent calls automatically route around the failed replica\n", + "result = await counter_service.increment.choose()\n", + "print(f\"Still working: {result}\")\n", + "\n", + "# The failed replica will be automatically recovered" + ] + }, + { + "cell_type": "markdown", + "id": "18e296b5-f1d5-4c3e-9f58-9842b78b1cc7", + "metadata": {}, + "source": [ + "Behind the scenes: Forge marks unhealthy replicas, routes traffic away from them, and spawns replacements automatically.\n", + "\n", + "### Why This Matters for RL\n", + "These service abstractions solve critical RL infrastructure challenges:\n", + "\n", + "1. Load balancing: Distribute rollouts across policy replicas efficiently\n", + "2. Sticky sessions: Maintain state between rollouts and their associated replicas, e.g. KV cache consistency\n", + "3. Fault tolerance: Keep training running even when individual nodes fail\n", + "4. Operational simplicity: No infrastructure code in your RL algorithms\n", + "\n", + "### Performance: Control Plane vs Data Plane\n", + "One important area we haven't covered yet is how Forge separates the **control plane** (service coordination) from the **data plane** (tensor transfers). 
You might reasonably wonder about performance implications if all data flows through TCP in a service-based architecture.\n", + "\n", + "We're actively developing **TorchStore** - our solution for high-performance tensor storage and retrieval over high-bandwidth interconnects like RDMA. This separation ensures that while Forge services handle coordination and routing, heavy tensor operations bypass the service layer entirely.\n", + "\n", + "*TorchStore will be covered in detail before our official release.*\n", + "\n", + "\n", + "Next, we'll see how these building blocks enable elegant RL algorithm expression." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44344bab-edcd-4851-98e4-59d8f9a4f3d8", + "metadata": {}, + "outputs": [], + "source": [ + "await shutdown_service(counter_service)" + ] + }, + { + "cell_type": "markdown", + "id": "e3a941c8-a744-4666-a1cb-67849b58e80a", + "metadata": {}, + "source": [ + "## Forge-Native Services\n", + "Now let's see the power of this abstraction in action. Forge provides service implementations of common RL components that constitute typical training workloads, for instance:\n", + "- Policy: Responsible for generating trajectories and responses\n", + "- Trainer: Responsible for updating policy weights\n", + "- Reference Model: Responsible for computing reference logprobs to prevent policy drift\n", + "- Reward: Responsible for evaluating trajectory quality\n", + "- Dataset: Responsible for serving prompts and target answers\n", + "- Advantage: Responsible for computing advantages from trajectories\n", + "\n", + "\n", + "### Building a Synchronous RL Workflow\n", + "Let's demonstrate by building a simple on-policy RL workflow. We'll start by spinning up multiple services using a small Qwen model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c622ed81-8fbc-4bd6-95f8-11a5df682711", + "metadata": {}, + "outputs": [], + "source": [ + "from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig\n", + "from forge.actors.replay_buffer import ReplayBuffer\n", + "from forge.controller.actor import ForgeActor\n", + "from forge.controller.service import ServiceConfig, shutdown_service, spawn_service\n", + "from forge.data.rewards import MathReward, ThinkingReward\n", + "from apps.grpo.main import Trainer, RewardActor, ComputeAdvantages, RefModel, DatasetActor, Group, Episode\n", + "\n", + "\n", + "model = \"Qwen/Qwen3-1.7B\"\n", + "group_size = 1\n", + "\n", + "(\n", + " dataloader,\n", + " policy,\n", + " trainer,\n", + " replay_buffer,\n", + " compute_advantages,\n", + " ref_model,\n", + " reward_actor,\n", + ") = await asyncio.gather(\n", + " spawn_service(\n", + " ServiceConfig(procs_per_replica=1, num_replicas=1),\n", + " DatasetActor,\n", + " path=\"openai/gsm8k\",\n", + " config_name=\"main\",\n", + " split=\"train\",\n", + " streaming=True,\n", + " ),\n", + " spawn_service(\n", + " ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),\n", + " Policy,\n", + " config=PolicyConfig(\n", + " worker_params=WorkerConfig(model=model),\n", + " sampling_params=SamplingOverrides(\n", + " num_samples=group_size, max_tokens=16\n", + " ),\n", + " ),\n", + " ),\n", + " spawn_service(\n", + " ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),\n", + " Trainer,\n", + " learning_rate=1e-5,\n", + " beta=0.1,\n", + " model_name=model,\n", + " ),\n", + " spawn_service(\n", + " ServiceConfig(procs_per_replica=1, num_replicas=1),\n", + " ReplayBuffer,\n", + " 
batch_size=2,\n", + " max_policy_age=1,\n", + " ),\n", + " spawn_service(\n", + " ServiceConfig(procs_per_replica=1, num_replicas=1),\n", + " ComputeAdvantages,\n", + " gamma=0.99,\n", + " lambda_=0.95,\n", + " ),\n", + " spawn_service(\n", + " ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),\n", + " RefModel,\n", + " model_name=model,\n", + " ),\n", + " spawn_service(\n", + " ServiceConfig(procs_per_replica=1, num_replicas=1),\n", + " RewardActor,\n", + " reward_functions=[MathReward(), ThinkingReward()],\n", + " ))\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "0f6e86de-0ba8-4a79-84ec-78bf99a96616", + "metadata": {}, + "source": [ + "What's happening here:\n", + "- Each service is spawned independently with its own configuration\n", + "- GPU services like the `policy`, `trainer`, and `ref_model` get GPU allocation\n", + "- All services run concurrently and can be scaled independently\n", + "- The same model is used for policy and reference, but they're separate services\n", + "\n", + "Notice what we're not doing:\n", + "- Managing CUDA placement across services\n", + "- Coordinating distributed training setup\n", + "- Handling inter-service communication protocols\n", + "- Writing fault tolerance and retry logic\n", + "\n", + "All of this is handled by our Service abstraction.\n", + "\n", + "Let's check that the services indeed work as expected:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8334fdf0-0afa-48c2-bb42-faa46e8be36a", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"What is 3 + 5?\"\n", + "responses = await policy.generate.choose(prompt=prompt)\n", + "\n", + "print(responses)" + ] + }, + { + "cell_type": "markdown", + "id": "22dc3db0-2402-4344-ae2a-658ea46d9792", + "metadata": {}, + "source": [ + "The response quality isn't great (it's only a 1.7B model!), but the service infrastructure is working perfectly.\n", + "\n", + "## Building the RL Training Loop\n", + "### The Role of RL in Post-Training\n", + "One way to interpret the role of RL in post-training is to align a base pre-trained model towards hard-to-define targets. The goal is \"sampling\" the right data that we think will best align the model.\n", + "\n", + "This is the role of \"rollouts\" - creating the dataset used to update our policy. Rather than training on a static dataset, RL dynamically generates training data by having the current policy interact with the environment.\n", + "\n", + "Let's build a step-by-step synchronous training loop to see how these services work together. The basic RL cycle is:\n", + "\n", + "1. Collect Experience: Get a prompt, generate a response, evaluate the reward\n", + "2. Compute Rewards: Calculate how much better/worse each action was than expected\n", + "3. Store Experience: Add the episode to our replay buffer\n", + "4. Sample & Train: Get a batch of experiences and update the policy\n", + "5. 
Repeat: Continue this cycle to improve the policy\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "298547e7-d185-48d1-a4b7-d122f760b707", + "metadata": {}, + "outputs": [], + "source": [ + "from apps.grpo.main import Episode, Group\n", + "\n", + "\n", + "async def simple_rl_step():\n", + " \"\"\"Execute one complete RL training step\"\"\"\n", + " \n", + " # ===== Generate a rollout =====\n", + " sample = await dataloader.__next__.choose()\n", + " prompt, target = sample[\"question\"], sample[\"answer\"]\n", + " \n", + " print(f\"Prompt: {prompt}\")\n", + " print(f\"Target: {target}\")\n", + " \n", + " actions = await policy.generate.choose(prompt=prompt)\n", + " print(f\"Policy response: {actions[0].text}\")\n", + " \n", + " ref_logprobs = await ref_model.forward.choose(actions[0].token_ids) \n", + " reward = await reward_actor.evaluate_response.choose(\n", + " prompt=prompt, \n", + " response=actions[0].text, \n", + " target=target\n", + " )\n", + " print(f\"Reward: {reward}\")\n", + " \n", + " episode = Episode(\n", + " episode_id=0,\n", + " prompt=prompt,\n", + " target=target, \n", + " policy_version=0,\n", + " )\n", + " \n", + " episode.add_group(Group(\n", + " response=actions[0].text,\n", + " ref_logprobs=ref_logprobs,\n", + " reward=reward,\n", + " ))\n", + " \n", + " advantages = await compute_advantages.__call__.choose(episode.groups)\n", + " episode.groups[0].advantage = advantages[0]\n", + " print(f\"Advantage: {advantages[0]}\") \n", + " await replay_buffer.add.choose(episode)\n", + " print(\"Episode stored in replay buffer\")\n", + " \n", + " # ===== Train on the batch ===== \n", + " batch = await replay_buffer.sample.choose(curr_policy_version=0)\n", + " if batch is not None:\n", + " print(\"Training on batch...\")\n", + " training_result = await trainer.train_step.choose(batch)\n", + " loss = training_result.get(\"loss\", 0.0)\n", + " print(f\"Training loss: {loss}\")\n", + " return loss\n", + " else:\n", + " print(\"Not enough data in buffer yet\")\n", + " return None\n", + "\n", + "\n", + "for step in range(10):\n", + " print(f\"\\n--- RL Step {step + 1} ---\")\n", + " loss = await simple_rl_step()\n", + " # Check against None explicitly - a valid loss of 0.0 would otherwise be treated as \"no batch\"\n", + " if loss is not None:\n", + " print(f\"Step {step + 1} complete, loss: {loss:.4f}\")\n", + " else:\n", + " print(f\"Step {step + 1} complete, building buffer...\")" + ] + }, + { + "cell_type": "markdown", + "id": "d03fb780-54ef-445e-85ef-5708ceb8e5be", + "metadata": {}, + "source": [ + "**Note**: The model responses aren't great (1.7B parameters + 16 token limit = not exactly o1!), but notice how clean the RL algorithm code is. The power of these abstractions is that you can focus on the algorithm logic while all the distributed coordination happens automatically behind the scenes.\n", + "\n", + "TODO - conclude this with trainer->inference weight sync and demonstrate how the response changes\n", + "\n", + "## Next Steps\n", + "This simple example demonstrates the core concepts, but for a production-ready implementation, check out our full GRPO (Group Relative Policy Optimization) example at apps/grpo/main.py. It includes the complete async training loops, proper logging, model weight synchronization, and all the optimizations needed for large-scale RL training.\n",
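+ "\n",
+ "For a flavor of what the asynchronous version looks like, here is a minimal sketch (not the apps/grpo implementation) that splits the loop above into a rollout task and a training task, using only the service endpoints already used in this notebook. It omits the weight synchronization, policy versioning, and logging that the full example adds:\n",
+ "\n",
+ "```python\n",
+ "async def continuous_rollouts():\n",
+ "    # Keep generating episodes and pushing them into the replay buffer.\n",
+ "    while True:\n",
+ "        sample = await dataloader.__next__.choose()\n",
+ "        prompt, target = sample[\"question\"], sample[\"answer\"]\n",
+ "        actions = await policy.generate.choose(prompt=prompt)\n",
+ "        ref_logprobs = await ref_model.forward.choose(actions[0].token_ids)\n",
+ "        reward = await reward_actor.evaluate_response.choose(\n",
+ "            prompt=prompt, response=actions[0].text, target=target\n",
+ "        )\n",
+ "        episode = Episode(episode_id=0, prompt=prompt, target=target, policy_version=0)\n",
+ "        episode.add_group(Group(response=actions[0].text, ref_logprobs=ref_logprobs, reward=reward))\n",
+ "        advantages = await compute_advantages.__call__.choose(episode.groups)\n",
+ "        episode.groups[0].advantage = advantages[0]\n",
+ "        await replay_buffer.add.choose(episode)\n",
+ "\n",
+ "\n",
+ "async def continuous_training():\n",
+ "    # Train whenever the buffer can serve a batch; otherwise back off briefly.\n",
+ "    while True:\n",
+ "        batch = await replay_buffer.sample.choose(curr_policy_version=0)\n",
+ "        if batch is None:\n",
+ "            await asyncio.sleep(0.1)\n",
+ "        else:\n",
+ "            await trainer.train_step.choose(batch)\n",
+ "\n",
+ "\n",
+ "# Run both loops concurrently; cancel the tasks to stop them.\n",
+ "# rollout_task = asyncio.create_task(continuous_rollouts())\n",
+ "# training_task = asyncio.create_task(continuous_training())\n",
+ "```"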
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff8b2f2e-98f6-41cd-865e-14cfa23c510f", + "metadata": {}, + "outputs": [], + "source": [ + "await asyncio.gather(\n", + " shutdown_service(policy),\n", + " shutdown_service(trainer),\n", + " shutdown_service(replay_buffer),\n", + " shutdown_service(dataloader),\n", + " shutdown_service(compute_advantages),\n", + " shutdown_service(ref_model),\n", + " shutdown_service(reward_actor))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0be0a295-5d84-497d-a1b1-0bef53d9c753", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}