IR.py defines the Pipe class, which is the main intermediate representation used in PiPPy. This intermediate representation consists of a restricted fx.GraphModule. In the top level fx.Graph representation, the IR is limited to only the following node types:
- `placeholder` and `output` nodes to specify overall pipeline inputs and outputs, respectively
- `call_module` nodes to represent calls into the pipeline stages
  - A single call to the `_loss` submodule can also be present to represent the loss computation for training
- `call_function` with a target of `operator.getitem`, for unpacking tuple outputs from pipeline stages
- `call_function` with a target of `IR.stage_backward` for the purpose of modeling the backward computation of each pipeline stage. This is described later in the section about IR for the backward pass.
- `call_function` with a target of `torch.add`, emitted solely for accumulating gradients of values that have multiple uses in the backward pass.
The top-level fx.Graph gives us 1) a topological ordering of pipeline stages and 2) the data dependencies between these pipeline stages. Note that this is more general than existing pipeline APIs, as it supports arbitrary non-local (i.e. skip) connections between stages.
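As a rough illustration of the restriction described above, the following sketch checks a toy top-level graph against the permitted node kinds and extracts the stage-to-stage data dependencies, including a skip connection. The `Node` tuple, the `ALLOWED` table, and both helper functions are hypothetical stand-ins written in plain Python; they are not PiPPy or torch.fx APIs.

```python
from typing import List, NamedTuple, Set, Tuple


class Node(NamedTuple):
    name: str
    op: str                       # "placeholder", "call_module", "call_function", "output"
    target: str                   # submodule or function name ("" if unused)
    inputs: Tuple[str, ...] = ()


# Permitted node kinds; None means any target is accepted for that op.
ALLOWED = {
    "placeholder": None,
    "output": None,
    "call_module": None,
    "call_function": {"operator.getitem", "IR.stage_backward", "torch.add"},
}


def is_valid_pipe_graph(nodes: List[Node]) -> bool:
    """Return True if every node is one of the permitted kinds."""
    for node in nodes:
        if node.op not in ALLOWED:
            return False
        targets = ALLOWED[node.op]
        if targets is not None and node.target not in targets:
            return False
    return True


def stage_dependencies(nodes: List[Node]) -> Set[Tuple[str, str]]:
    """Return (producer_stage, consumer_stage) edges, looking through getitem nodes."""
    by_name = {n.name: n for n in nodes}

    def producer(name: str) -> Node:
        node = by_name[name]
        # A getitem node only unpacks a stage output, so follow it back to the stage.
        while node.op == "call_function" and node.target == "operator.getitem":
            node = by_name[node.inputs[0]]
        return node

    edges = set()
    for node in nodes:
        if node.op == "call_module":
            for inp in node.inputs:
                src = producer(inp)
                if src.op == "call_module":
                    edges.add((src.name, node.name))
    return edges


graph = [
    Node("x", "placeholder", ""),
    Node("submod_0", "call_module", "submod_0", ("x",)),
    Node("getitem", "call_function", "operator.getitem", ("submod_0",)),
    Node("getitem_1", "call_function", "operator.getitem", ("submod_0",)),
    Node("submod_1", "call_module", "submod_1", ("getitem",)),
    Node("submod_2", "call_module", "submod_2", ("submod_1", "getitem_1")),
    Node("output", "output", "", ("submod_2",)),
]
print(is_valid_pipe_graph(graph))
print(sorted(stage_dependencies(graph)))
```

The edge from `submod_0` to `submod_2` is the kind of non-local connection that a purely sequential pipeline API could not express.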
We can create IR from existing PyTorch modules using one of several front-ends, exposed as static methods on Pipe. Pipe.from_sequential takes as argument an instance of torch.nn.Sequential and returns a Pipe instance that represents the trivial feed-forward nature of that sequential. For example:
mods = [torch.nn.Linear(512, 512) for _ in range(5)]
mods += [mods[0]]
seq = torch.nn.Sequential(*mods)
seq_pipe = Pipe.from_sequential(seq)
print(seq_pipe.split_gm)
"""
GraphModule(
  (submod_0): Linear(in_features=512, out_features=512, bias=True)
  (submod_1): Linear(in_features=512, out_features=512, bias=True)
  (submod_2): Linear(in_features=512, out_features=512, bias=True)
  (submod_3): Linear(in_features=512, out_features=512, bias=True)
  (submod_4): Linear(in_features=512, out_features=512, bias=True)
  (submod_5): Linear(in_features=512, out_features=512, bias=True)
)

def forward(self, input):
    input_1 = input
    _0 = self.submod_0(input_1);  input_1 = None
    _1 = self.submod_1(_0);  _0 = None
    _2 = self.submod_2(_1);  _1 = None
    _3 = self.submod_3(_2);  _2 = None
    _4 = self.submod_4(_3);  _3 = None
    _5 = self.submod_5(_4);  _4 = None
    return _5
"""
Similarly, we can use pipeline, which relies on torch.fx tracing, to convert an arbitrary nn.Module instance to this form. For example:
class ExampleCode(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mm_param = torch.nn.Parameter(torch.randn(512, 512))
        self.mm_param2 = torch.nn.Parameter(torch.randn(512, 512))
        self.lin = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = torch.mm(x, self.mm_param)
        skip_connection = x
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param)
        x = self.lin(x)
        pipe_split()
        x = torch.relu(x)
        x = x + skip_connection
        x = torch.mm(x, self.mm_param2)
        x = self.lin(x)
        return x

ec = ExampleCode()
ec(torch.randn(50, 512))
ec_pipe = pipeline(ec, MultiUseParameterConfig.TRANSMIT)
print(ec_pipe.split_gm)
"""
GraphModule(
  (submod_0): GraphModule()
  (submod_1): GraphModule(
    (lin): Linear(in_features=512, out_features=512, bias=True)
  )
  (submod_2): GraphModule(
    (lin): Linear(in_features=512, out_features=512, bias=True)
  )
)

def forward(self, x):
    submod_0 = self.submod_0(x);  x = None
    getitem_2 = submod_0[2]
    getitem = submod_0[0]
    getitem_1 = submod_0[1];  submod_0 = None
    submod_1 = self.submod_1(getitem, getitem_2);  getitem = getitem_2 = None
    submod_2 = self.submod_2(submod_1, getitem_1);  submod_1 = getitem_1 = None
    return submod_2
"""
There are a few things to note about the above example:
- We use `IR.pipe_split` to explicitly demarcate within the code where we want pipeline boundaries to be. `from_tracing` will collect all data dependencies across these calls to `pipe_split` and emit corresponding data dependencies in the pipeline graph.
  - Note that `IR.PipeSplitWrapper` and `IR.annotate_split_points` can be used to non-intrusively specify split points at the beginning or end of execution of any Module in the module hierarchy
- Note the `skip_connection` value in the original program. `from_tracing` will correctly detect the usage of this value in non-adjacent pipeline stages and emit a connection in the top-level graph to forward this dependency from stage 0 to 2.
- Notice that `self.mm_param` is used both in pipeline stage 0 and pipeline stage 1. Since we have specified `MultiUseParameterConfig.TRANSMIT` as the `multi_use_param_spec` argument to `from_tracing`, the system will emit code that will keep `mm_param` resident on stage 0 and transmit that value for use within stage 1. `multi_use_param_spec` can also be specified as a dictionary mapping parameter qualified names to a `MultiUseParameterConfig` value (one of `TRANSMIT` or `REPLICATE`) or it can be left as `None` to specify the default behavior (`TRANSMIT`) for all shared parameters. We will discuss replication in the following section.
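The three accepted forms of `multi_use_param_spec` (a single value, a dictionary keyed by parameter qualified name, or `None`) can be sketched as a small lookup. The enum below is a local stand-in that only mirrors the two option names from the text, and `resolve_policy` is a hypothetical helper, not a PiPPy function. The fallback for names missing from the dictionary is an assumption.

```python
from enum import Enum
from typing import Dict, Union


class MultiUseParameterConfig(Enum):
    # Local stand-in mirroring the two options named in the text.
    TRANSMIT = 1
    REPLICATE = 2


Spec = Union[None, MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]]


def resolve_policy(spec: Spec, param_name: str) -> MultiUseParameterConfig:
    """Return the sharing policy for one multi-use parameter."""
    if spec is None:
        # Default behavior described in the text: transmit every shared parameter.
        return MultiUseParameterConfig.TRANSMIT
    if isinstance(spec, MultiUseParameterConfig):
        # A single value applies to all shared parameters.
        return spec
    # Assumption: names absent from the dictionary fall back to the default.
    return spec.get(param_name, MultiUseParameterConfig.TRANSMIT)


per_param = {"mm_param": MultiUseParameterConfig.REPLICATE}
print(resolve_policy(per_param, "mm_param"))
print(resolve_policy(None, "mm_param"))
```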
Multi-use parameters can also be replicated. That is, each pipeline stage that uses a replicated parameter will have its own copy of the parameter and the system will record information about this replication such that the runtime can insert the proper synchronization operations upon update of these parameters. For example, let us rerun the above example with multi_use_param_spec=MultiUseParameterConfig.REPLICATE:
ec_pipe_replicated = pipeline(ec, MultiUseParameterConfig.REPLICATE)
print(ec_pipe_replicated.replicated_params)
"""
[{'submod_0': '__mm_param', 'submod_1': '__mm_param'},
{'submod_1': 'lin.weight', 'submod_2': 'lin.weight'},
{'submod_1': 'lin.bias', 'submod_2': 'lin.bias'}]
"""
Note that the Pipe instance has an attribute replicated_params, which is a record of all of the parameters that are replicated across pipeline stages. This object is a list of dictionaries. Each dictionary represents a single value that has been replicated across stages. The keys of the dictionary are the qualified names of the pipeline stage submodules that hold copies of this parameter, and the values are the qualified names of the parameter itself within those pipeline stage modules. Note that not only do we see mm_param in the above example, but we also see parameter replication from the usage of the self.lin module in multiple pipeline stages. self.lin is a "leaf module" in torch.fx parlance, and since we cannot see into the implementation of a leaf module, we automatically replicate leaf module parameters (note that we could hypothetically emit code to fetch the parameter values from leaf modules and transmit them to use sites, but that will require further development work).
torch.fx tracing imposes limitations on the classes of programs that can be captured (as described in Limitations of Symbolic Tracing). Thus, this limits the applicability of the above-described system. However, we can think of several ways to address this:
- Convert tracing into a "just-in-time" system, where program capture happens on each invocation, and is specialized to certain parameters like shapes flowing throughout the program. By using ephemeral traces that are transmitted to each worker on each invocation, we can address the limitations of e.g. dynamic control flow. However, we will need to figure out the semantics of parameter placement in this scenario, as moving those around on each invocation will likely be suboptimal
  - Note that the obvious drawback here is that program capture happens on each invocation, thus leading to overhead in the latency of the process. Projects like LazyTensor have been trying to address this via dispatch optimizations and pipelining scalar program computation with scalar program computation, but these approaches have proven difficult to achieve in a performant way. We can think of ways to make it so that large subsections of the program are seen as "basic blocks", or "builtin-instructions", analogous to the instructions of a processor. For example, the user could certify that a module and its recursive callees are straightline code, and we can dispatch directly to the pipelined version of those modules. We can also explore techniques of speculative execution + cancellation, speculating that we will enter a pipelined block when we see the prefix of its instructions and stall, cancel, and reissue if our speculation is wrong.
- We can formulate pipeline parallelism as a program (program counter, call stack, live heap content) that migrates through multiple devices. Thus, rather than defining semantics for program capture and analysis, we simply need to define runtime semantics for migrating a running coroutine between devices. This may be difficult to implement in existing languages (Python). This could be implemented in TorchScript, but then that would require the program to be admissible to TorchScript's program capture limitations. Maybe we should just make a new language.
The above ideas may be candidates for research investigation.
Pipe.from_sequential and pipeline also take a loss_fn argument to specify the loss computation in the training scenario. loss_fn can be an nn.Module instance or a free function. The module/function should take two positional arguments: the output of the feedforward computation and the target values. An example of using this API and the IR it produces can be seen here:
ec_pipe_with_loss = pipeline(ec, loss_fn=mse_loss)
print(ec_pipe_with_loss.split_gm)
"""
GraphModule(
  (submod_0): GraphModule()
  (submod_1): GraphModule(
    (lin): Linear(in_features=512, out_features=512, bias=True)
  )
  (submod_2): GraphModule(
    (lin): Linear(in_features=512, out_features=512, bias=True)
  )
  (_loss): MSELoss()
)

def forward(self, x, target):
    submod_0 = self.submod_0(x)
    getitem_2 = submod_0[2]
    getitem = submod_0[0]
    getitem_1 = submod_0[1]
    submod_1 = self.submod_1(getitem, getitem_2)
    submod_2 = self.submod_2(submod_1, getitem_1)
    _loss = self._loss(submod_2, target)
    stage_backward = __main___stage_backward(stage_output = _loss, output_grads = None, input_values = [submod_2, target]);  target = None
    getitem_3 = stage_backward[0]
    getitem_4 = stage_backward[1];  stage_backward = None
    stage_backward_1 = __main___stage_backward(stage_output = submod_2, output_grads = getitem_3, input_values = [submod_1, getitem_1]);  submod_2 = getitem_3 = getitem_1 = None
    getitem_5 = stage_backward_1[0]
    getitem_6 = stage_backward_1[1];  stage_backward_1 = None
    stage_backward_2 = __main___stage_backward(stage_output = submod_1, output_grads = getitem_5, input_values = [getitem, getitem_2]);  submod_1 = getitem_5 = getitem = getitem_2 = None
    getitem_7 = stage_backward_2[0]
    getitem_8 = stage_backward_2[1];  stage_backward_2 = None
    stage_backward_3 = __main___stage_backward(stage_output = submod_0, output_grads = [getitem_7, getitem_6, getitem_8], input_values = [x]);  submod_0 = getitem_7 = getitem_6 = getitem_8 = x = None
    return _loss
"""
Note the following:
- When `loss_fn` is specified, an additional positional input (`target`) is added to the signature of the model. During training, the value representing the target label used in loss computation should be passed in as this argument.
- The loss value is returned, allowing for e.g. logging of loss values during training.
- A simple symbolic automatic differentiation process emits code for computing the gradients of the model training process in a pipelined way. This is described below.
When thinking about how to implement backward() in a pipeline parallel scenario, there are two considerations:
- In PyTorch autograd, `requires_grad` dictates whether operations record values/functions in the autograd tape. However, this does not compel execution of the backward for those operations.
- `backward` is only run when a corresponding `backward()` call later in the program initiates gradient computation.
Given this, we should think about how these semantics could map onto pipeline parallel execution, especially given that we would like to arbitrarily schedule the forward/backward jobs in the computation (such as in 1f1b scheduling). There are basically two options here:
1. Emulate the “autograd tracing + just-in-time gradient computation” of Eager mode PyTorch. This may be a bit difficult to do in conjunction with schedules, as the dynamic nature of this type of autograd makes it more difficult to reason about if/when/what is run in autograd or to reference specific forward/backward executions in a schedule
2. Model `backward` ahead-of-time by emitting stages into the IR. This allows us to schedule all stages uniformly (we don’t need to essentially replicate the autograd machinery in the runtime) and allows us to know what backward stages will be run ahead of time and reference them in the scheduling system.
We elect to implement option (2). We implement this by doing a reverse iteration over the nodes of the Pipe module and applying the following rules for each type of node:
- `call_module`. For module calls, we want to compute the gradient of each tensor input or module parameter with `requires_grad` with respect to the gradient of each tensor output that `requires_grad`. We can emit a call to a function `stage_backward(output_vals, dout)` that is a wrapper over `autograd.backward` (or `autograd.grad`). This wrapper handles unpacking/packing collection type inputs/outputs (like a pytree) and delegates to `autograd.backward` to compute and accumulate gradient values into the `.grad` attribute for each input and parameter value of this pipeline stage. `stage_backward` returns the gradients of the input values of the pipeline stage.
  - TODO: For gradient checkpointing, we should wrap the forward() invocation of the forward module with torch.utils.checkpoint. Then we can compute gradients in the same manner (TODO: is this right?)
  - TODO: is it good enough to dynamically figure out which tensor inputs require grad?
  - TODO: if the input tensors are values sent over the wire from the remote, do they have any attached grad_fn? If so, do we need to block the gradient somehow? Can we detach?
  - TODO: how to transmit saved output tensors on the same device without putting them through RPC? i.e. the data dependencies from the forward to the backward phase should just be passing a tensor reference, not going over RPC
    - Maybe just special-case this in the executor?
  - TODO: Zach mentioned that it is not always necessary to save the whole output tensor for computing the gradient. e.g. gradient of matmul does not require the output in the formulae for gradient of its inputs. Is there a way to call autograd.backward and only pass in the grad_fns and not the output tensors themselves? ask alban
- `call_function` + `operator.getitem`. This is used solely for the purpose of indexing into tuple outputs of stages if a stage has multiple outputs. In the backwards, the corresponding operation should be to rebuild the collection type for the purpose of passing it to `stage_backward`. We need to lazily build these collection types as we iterate in reverse order over the program
- `placeholder`
  - TODO: should we return gradients of pipeline inputs? Does anyone need this?
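The reverse iteration can be sketched symbolically on a toy graph. This is a deliberately simplified model, not PiPPy's implementation: every stage is assumed to have a single output (so no `getitem` handling is needed), and `Node`, `emit_backward`, and the emitted text format are all hypothetical. It does show the two ideas from the text: one `stage_backward` call per stage, and an `add` to accumulate gradients for a value with multiple uses.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Node:
    name: str
    op: str                       # "placeholder", "call_module", or "output"
    inputs: Tuple[str, ...] = ()


def emit_backward(nodes: List[Node]) -> List[str]:
    """Walk the forward nodes in reverse and emit symbolic backward calls."""
    grad_of: Dict[str, str] = {}  # forward value name -> expression for its gradient
    emitted: List[str] = []
    num_adds = 0
    for node in reversed(nodes):
        if node.op != "call_module":
            continue  # placeholders and the output node emit nothing in this sketch
        # The final stage (the loss) has no incoming gradient, hence "None".
        out_grad = grad_of.get(node.name, "None")
        bwd = f"bwd_{node.name}"
        args = ", ".join(node.inputs)
        emitted.append(f"{bwd} = stage_backward({node.name}, {out_grad}, [{args}])")
        for i, inp in enumerate(node.inputs):
            grad = f"{bwd}[{i}]"
            if inp in grad_of:
                # The value is used by more than one stage: accumulate gradients.
                acc = f"add_{num_adds}"
                num_adds += 1
                emitted.append(f"{acc} = add({grad_of[inp]}, {grad})")
                grad = acc
            grad_of[inp] = grad
    return emitted


forward_nodes = [
    Node("x", "placeholder"),
    Node("s0", "call_module", ("x",)),
    Node("s1", "call_module", ("s0",)),
    Node("s2", "call_module", ("s1", "s0")),   # skip connection from s0
    Node("loss", "call_module", ("s2",)),
    Node("output", "output", ("loss",)),
]
plan = emit_backward(forward_nodes)
for line in plan:
    print(line)
```

Because the nodes are visited in reverse topological order, every consumer of `s0` has contributed its gradient before the backward call for `s0` itself is emitted.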
PipelineDriver.py contains the implementation for a single-driver multiple-follower runtime that interprets the abstract IR described above. The classes contained within this file are the following:
- `PipelineDriverBase` is the base class that specifies the interface for pipeline runtimes. This abstract class must be subclassed and its abstract methods implemented to specify a given pipeline parallel schedule/execution semantics. `PipelineDriverBase` specifies the following methods:
  - `__init__` takes various configuration parameters for this schedule. `pipe` is the pipelined program to run. `world_size` and `all_ranks` specify the execution environment using parameters that match the `torch.distributed.rpc` rank and world size concepts. `all_ranks` allows you to specify arbitrary ranks to run the pipeline on, otherwise `range(world_size)` is assumed. `single_loss` allows you to specify that losses from all micro-batches should be computed via a single application of the loss function, rather than individual applications for each micro-batch. (Note that `single_loss=True` is not currently implemented). `_debug_mask_minibatches` specifies to send masked versions of the mini-batch through instead of micro-batch slices--this can be used for more stable numerical testing (see [A Note About Correctness Testing])
  - `run(self, chunks : int, *args, **kwargs)`. This is the main entrypoint for running a mini-batch through the pipeline. `chunks` is the number of chunks (micro-batches) to split the minibatch into. `*args` and `**kwargs` are the input values to the model code (Tensor values should have exactly one batch dimension along which to divide). `batch-dims` specifies--for each `Tensor` input in the same order they appear in `args`--what the batch dimensions are for each Tensor (if `None`, the 0th dimension is assumed to be the batch dimension for each tensor).
- `MicroBatchSplitTensor` is a data structure to represent an input (an element of `args`) that has been split into chunks.
The following implementations of PipelineDriverBase exist:
- `PipelineDriverFillDrain` implements a GPipe-style fill-drain schedule.
  - `run` implements the run interface from `PipelineDriverBase` and will execute a mini-batch of examples by splitting it up into micro-batches, pumping those through the ranks in a fill-drain pipeline manner, and returning the result (or loss value) and computing gradients (if the `Pipe` instance has backwards specified).
- `PipelineDriver1F1B` implements a 1F1B schedule.
Implementation details:
- `RemoteInterpreter` splits an input mini-batch into micro-batches and interprets the top-level `Pipe` graph, issuing `invoke` calls to the associated `PipeStageExecutors` to orchestrate execution of the program in a pipelined fashion. `RemoteInterpreter` exposes a custom `run_until` method, which allows you to run the given interpreter until a given predicate is true for the next node to be executed. This allows you to implement schedules by executing subsets of the full pipeline and interleaving the computation from different micro-batches onto the `PipeStageExecutor`.
- At the end of execution of the `forward`, `PipelineDriver` will synchronize the parameters that were marked as replicated in the `replicated_params` attribute on `Pipe`. Essentially, a replicated parameter is duplicated during pipelining. This duplication splits the usages of the parameter into multiple disjoint copies, one per stage. Each of these copies of the parameter will receive gradient signal coming from the operations in the stage in which it resides. To get the full gradient signal for the parameter as written in the original model, we need to sum up the gradients from each copy. `_sync_replicated_params` currently does this by downloading the copies to the master rank, summing up the values, and broadcasting the summed value to each of the stages that has a copy of the parameter. After that, the optimizer can be applied correctly.
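The gather, sum, and broadcast step for replicated parameters can be sketched with plain floats standing in for gradient tensors. The function name `sync_replicated_grads` and the nested-dictionary layout of `grads` are assumptions made for this illustration; only the list-of-dictionaries shape of `replicated_params` comes from the text.

```python
from typing import Dict, List


def sync_replicated_grads(
    replicated_params: List[Dict[str, str]],
    grads: Dict[str, Dict[str, float]],
) -> None:
    """Sum each replicated parameter's gradients and write the sum to every copy.

    replicated_params: one dict per replicated value, mapping stage name to the
        parameter's qualified name inside that stage.
    grads: hypothetical layout, grads[stage][param_name] -> gradient value.
    """
    for replica_group in replicated_params:
        # "Download" every copy's gradient and add them up.
        total = sum(grads[stage][name] for stage, name in replica_group.items())
        # "Broadcast" the summed gradient back to each stage holding a copy.
        for stage, name in replica_group.items():
            grads[stage][name] = total


replicated_params = [
    {"submod_0": "__mm_param", "submod_1": "__mm_param"},
    {"submod_1": "lin.weight", "submod_2": "lin.weight"},
]
grads = {
    "submod_0": {"__mm_param": 1.0},
    "submod_1": {"__mm_param": 2.0, "lin.weight": 0.5},
    "submod_2": {"lin.weight": 0.25},
}
sync_replicated_grads(replicated_params, grads)
print(grads)
```

After the call, every copy of a replicated parameter holds the same summed gradient, so applying the optimizer independently on each stage keeps the copies in step.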
In general, the inputs and outputs to a given PyTorch model/program can be any combination of primitive or collection types (including torch.Tensor). Pipelining in DL training relies upon the following two assumptions:
- The program can be split across instructions (i.e. program text) and
- The program can be split and parallelized across input data, run in parallel, and joined together at the output data
Regarding the second requirement, we need to define both an API and a splitting semantics to a) confirm the program is splittable this way as well as b) actually do the splitting at runtime.
We'll need three components:
1. An API specifying how to decompose input values into constituent chunked (or replicated) components. This could be implemented with something like pytree
2. An inference algorithm to ensure that the program can indeed be split in this way. Some cases where this breaks include cases like BatchNorm, which has a reduction operation across the batch dimension
3. An API similar to (1) that specifies how the single output value should be reconstructed from the chunked outputs.
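A minimal sketch of the input-decomposition and output-reconstruction APIs, using Python lists in place of tensors. The `CHUNK`/`REPLICATE` markers and the `split_args` and `merge_outputs` helpers are hypothetical names for illustration; the inference step that checks splittability is not modeled here.

```python
from typing import Any, List, Sequence

CHUNK = "chunk"          # split this argument along its batch dimension
REPLICATE = "replicate"  # pass the same value to every micro-batch


def split_args(args: Sequence[Any], specs: Sequence[str], chunks: int) -> List[List[Any]]:
    """Return one argument list per micro-batch."""
    per_chunk: List[List[Any]] = [[] for _ in range(chunks)]
    for arg, spec in zip(args, specs):
        if spec == CHUNK:
            if len(arg) % chunks != 0:
                raise ValueError("batch size must be divisible by the number of chunks")
            size = len(arg) // chunks
            for i in range(chunks):
                per_chunk[i].append(arg[i * size:(i + 1) * size])
        else:
            for i in range(chunks):
                per_chunk[i].append(arg)
    return per_chunk


def merge_outputs(outputs: Sequence[List[Any]]) -> List[Any]:
    """Reconstruct the full-batch output by concatenating micro-batch outputs."""
    merged: List[Any] = []
    for out in outputs:
        merged.extend(out)
    return merged


def scale(batch: List[float], factor: float) -> List[float]:
    # A per-example program, so it is safe to split across the batch.
    return [x * factor for x in batch]


batch = [1.0, 2.0, 3.0, 4.0]
micro_args = split_args([batch, 10.0], [CHUNK, REPLICATE], chunks=2)
merged = merge_outputs([scale(*a) for a in micro_args])
print(merged)
```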
We can examine two types of schedules to extract out the requirements for a general, programmable system for issuing and scheduling computation:
- Synchronous Fill-Drain (aka GPipe[1])
- Synchronous 1F1B (aka PipeDream-Flush) (from PipeDream[2] + Megatron 2 paper[3])
We can further examine what needs to happen at different stages of the runtime for each strategy. The two stages we'll examine are:
- Partitioning the work and distributing (microbatch, phase) pairs to each of the pipeline stages (in the form of `WorkItem`s)
- Scheduling `WorkItem`s at runtime, subject to the policy specified by a schedule
| | Fill-Drain (GPipe) | Synchronous 1F1B |
|---|---|---|
| WorkItem Issue | All forward WorkItems, then full mini-batch loss, then all backward WorkItems | All forward WorkItems, then micro-batch loss, then all backward WorkItems |
| Online Schedule | Consume and process WorkItems in order. Note that we may want to make the loss computation configurable in terms of whether it happens over the full mini-batch or per-micro-batch. | Consume forward WorkItems in order until steady state, then alternate forward and backward until drain, then consume backward WorkItems in-order. Note that the last stage must execute the loss computation per-micro-batch |
Given the above, we should implement extension points for both the RemoteInterpreter class that issues WorkItems to each stage, as well as the PipeStageExecutor class that handles execution of WorkItems at runtime.
Idea: compiler/processor split
- Compiler: orders instructions from each micro-batch into some order for consumption by the processor
- Processor: configurable execution engine. Mainly can be configured for in-order or out-of-order execution.
We can organize the schedules along the compiler/processor split:
- Fill-drain: Compiler orders all forward chunks, loss, then all backward. Could either be an in-order processor or an out-of-order processor. In the case of OOO, compiler will emit barrier instruction
- 1F1B: Compiler orders chunks in 1f1b order. In-order processor, strict about ordering
- Dynamic: Compiler orders chunks in any order. Out-of-order processor with registers/resource limits.
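The "compiler" half of this split can be sketched as a function that produces the per-stage order of (phase, micro-batch) work. This is an illustrative sketch only: `WorkItem` here is a plain tuple rather than the runtime's class, loss computation is left out, and the warm-up rule used for 1F1B (`num_stages - stage - 1` forward passes before alternating) is the commonly described one, assumed rather than taken from this codebase.

```python
from typing import List, Tuple

WorkItem = Tuple[str, int]  # (phase, micro-batch index); phase is "F" or "B"


def fill_drain_order(chunks: int) -> List[WorkItem]:
    """All forward micro-batches, then all backward micro-batches."""
    forwards = [("F", i) for i in range(chunks)]
    backwards = [("B", i) for i in range(chunks)]
    return forwards + backwards


def one_f_one_b_order(stage: int, num_stages: int, chunks: int) -> List[WorkItem]:
    """Warm up with forwards, alternate forward/backward, then drain backwards."""
    warmup = min(num_stages - stage - 1, chunks)
    order: List[WorkItem] = [("F", i) for i in range(warmup)]
    steady = chunks - warmup
    for i in range(steady):
        order.append(("F", warmup + i))
        order.append(("B", i))
    order.extend(("B", i) for i in range(steady, chunks))
    return order


print(fill_drain_order(3))
print(one_f_one_b_order(stage=0, num_stages=3, chunks=4))
print(one_f_one_b_order(stage=2, num_stages=3, chunks=4))
```

An in-order processor would simply consume each list front to back. An out-of-order processor would need extra constraints (such as the barrier mentioned above) to avoid reordering across the forward/backward boundary.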
Note that micro-batch splitting and reconstruction is not guaranteed to be bitwise-equivalent to running the same program on the full batch (see here). See also exps/split_example.py, which demonstrates this when the constant USE_WHOLE_BATCH is set to False. A proposed way to get around this when testing for correctness is to run the full batch through the network for each micro-batch invocation, slice out the results from the full batch that correspond to each micro-batch, and then cat those partial results together. This is demonstrated when USE_WHOLE_BATCH is True. This should guarantee numerical equivalence during testing while still exercising the micro-batch pipelining machinery.
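The testing trick can be sketched in plain Python. To make the mismatch visible without floating-point subtleties, this toy uses a function that depends on batch statistics. In the case the text describes, the differences are small numerical ones, so treat the `center` function as an exaggerated stand-in. The helper name `run_micro_batches` and its `use_whole_batch` flag are hypothetical and only loosely mirror `USE_WHOLE_BATCH`.

```python
from typing import Callable, List


def center(batch: List[float]) -> List[float]:
    # Depends on the whole batch, so splitting changes the result.
    mean = sum(batch) / len(batch)
    return [x - mean for x in batch]


def run_micro_batches(
    fn: Callable[[List[float]], List[float]],
    batch: List[float],
    chunks: int,
    use_whole_batch: bool,
) -> List[float]:
    """Run fn once per micro-batch and concatenate the per-chunk results."""
    if len(batch) % chunks != 0:
        raise ValueError("batch size must be divisible by the number of chunks")
    size = len(batch) // chunks
    pieces: List[float] = []
    for i in range(chunks):
        lo, hi = i * size, (i + 1) * size
        if use_whole_batch:
            # Testing mode: compute on the full batch, keep only this chunk's rows.
            pieces.extend(fn(batch)[lo:hi])
        else:
            # Normal mode: compute on the micro-batch slice only.
            pieces.extend(fn(batch[lo:hi]))
    return pieces


batch = [1.0, 2.0, 3.0, 6.0]
reference = center(batch)
print(run_micro_batches(center, batch, 2, use_whole_batch=True) == reference)
print(run_micro_batches(center, batch, 2, use_whole_batch=False) == reference)
```

The whole-batch mode repeats the full computation once per chunk, so it is only suitable for correctness tests, not for measuring performance.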
During the training loop, we should focus on a few important elements:
- Data loader
- `forward`
  - This is already trivially handled by the current implementation. `Pipe` handles micro-batch computation of the model in the `forward` execution
- Loss
  - There are a few alternatives for API design here:
    - We can preserve the training loop as-is and make it so that invoking a loss computation on the output of the forward() computation issues jobs on the last stage to compute the loss. We could use something similar to DistributedLoss. An open question is how to represent the output of forward() in such a way that it can represent the forward() output for all micro-batches. This might look like a data structure that is a list of RRefs backed by async futures on the pipeline stages. Then, if we intercept computation on this loss object, we would issue `WorkItem`s for each operation for each micro-batch. However, this seems to degenerate down to a full-on tracing/single-coordinator system
    - We can make it so that the pipeline API takes the loss as a function argument and essentially encapsulates the whole training loop. This is much easier, probably less flaky (in terms of not needing to build a whole tracing mechanism), but is not super Pythonic. It may facilitate implementing async pipeline parallelism in the future
- `backward`
  - There are similar considerations for `backward` as there are for `loss`. `backward` is an invocation on the scalar loss value that will need to schedule jobs in the backend.
- Optimizer
  - Requirements here are similar to loss and backwards, but the optimizer step happens only once for the whole mini-batch, so it may be the case that this can literally be the same as the normal optimizer (potentially using DistributedOptimizer).

