Context Parallel
References and reading material:
There are three steps to enabling context parallelism with any model:

1. Define a CP plan for the model (described below).
2. Call the `apply_context_parallel` function: this registers the necessary hooks to split and gather tensors at the right places in the model, without having to manually modify the model code.
3. Wrap execution in the `attention_provider` context manager to select an attention backend that supports CP.

For a quick example, refer to the inference example below.
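At its core, context parallelism shards the sequence axis of activations across ranks before attention and gathers the shards afterwards. The following is a pure-Python sketch of that split/gather idea only; finetrainers implements it with torch tensors, hooks, and collective ops, not these helper functions.

```python
# Pure-Python illustration of splitting and gathering along the sequence axis.
# `split_seq` / `gather_seq` are illustrative helpers, not finetrainers APIs.

def split_seq(seq, rank, world_size):
    """Return the contiguous shard of `seq` owned by `rank`."""
    assert len(seq) % world_size == 0, "sequence length must divide evenly"
    shard = len(seq) // world_size
    return seq[rank * shard:(rank + 1) * shard]

def gather_seq(shards):
    """Concatenate per-rank shards back into the full sequence."""
    out = []
    for s in shards:
        out.extend(s)
    return out

tokens = list(range(8))  # stand-in for the sequence axis (S = 8)
world_size = 2
shards = [split_seq(tokens, r, world_size) for r in range(world_size)]
assert shards == [[0, 1, 2, 3], [4, 5, 6, 7]]
assert gather_seq(shards) == tokens
```

In the real implementation, each rank only ever materializes its own shard, and the gather is a communication collective across devices rather than a local concatenation.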
The CP plan is a dictionary keyed by the names of the internal modules in the model. The values are dictionaries that map a parameter identifier (either an argument index or a keyword argument name, as used in the forward method) to a `CPInput` or `CPOutput` object. A `CPInput` object specifies an input tensor to be split, and a `CPOutput` object specifies an output tensor to be gathered.

The `split_dim` and `gather_dim` parameters specify the dimension along which to split or gather the tensor. When using CP with native scaled dot product attention from PyTorch, the tensor shape is `[B, N, S, D]`, so `split_dim` and `gather_dim` should be set to `2`, the sequence dimension.

The `expected_dims` parameter is optional and acts as a sanity check that the tensor has the expected number of dimensions.

By default, `CPInput`s are split in a pre-forward hook and `CPOutput`s are gathered in a post-forward hook. If you want to split the output of a module instead, set the `split_output` parameter to `True`; the output tensor will then be split in the post-forward hook rather than in the pre-forward hook.

Attention providers supported for training with CP: `flash`, `_native_cudnn`, `_native_efficient`, `_native_flash`

Attention providers supported for inference with CP: `flash`, `_native_cudnn`, `_native_efficient`, `_native_flash`

Training
To enable training with context parallelism, make sure a suitable CP plan is registered for the model you are using and launch training with `--cp_degree N`, where N > 1. For models supported in finetrainers, this is done internally in the transformer metadata file. For custom models, make sure to pass the `plan` argument to the `apply_context_parallel` function.

Currently supported models include: CogVideoX, CogView4, Flux, Wan 2.1. Support for more models and attention providers is in progress.
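For a custom model, the `plan` argument might look like the sketch below. The `CPInput`/`CPOutput` classes here are stand-in dataclasses mirroring only the fields described above, and the module names are made up for illustration; the real classes and plans in finetrainers may differ.

```python
# Hypothetical sketch of a CP plan. CPInput / CPOutput are stand-ins that
# mirror the fields described above (split_dim, gather_dim, expected_dims);
# the actual finetrainers classes may differ.
from dataclasses import dataclass
from typing import Optional

@dataclass
class CPInput:
    split_dim: int
    expected_dims: Optional[int] = None
    split_output: bool = False

@dataclass
class CPOutput:
    gather_dim: int
    expected_dims: Optional[int] = None

# Keys are internal module names; values map a forward() parameter (by
# argument index or keyword name) to a split/gather spec. With [B, N, S, D]
# tensors, dim 2 is the sequence dimension.
plan = {
    "rope": {0: CPInput(split_dim=2, expected_dims=4)},          # illustrative module name
    "proj_out": {"hidden_states": CPOutput(gather_dim=2, expected_dims=4)},
}
```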
Inference
The following example shows how to run context parallel inference. For more examples and ready-to-use inference scripts, check out the examples/inference folder.
Example
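The original example block is not reproduced here; the following is a hedged sketch of what context-parallel inference could look like, based only on the names mentioned above (`apply_context_parallel`, `attention_provider`, `plan`). The import paths, pipeline loading, and call signatures are assumptions, not the actual finetrainers API; see the examples/inference folder for the real, ready-to-use scripts.

```python
# Hypothetical sketch -- import paths and signatures are assumptions.
# Launch with one process per GPU, e.g. `torchrun --nproc_per_node N script.py`.
import torch
import torch.distributed as dist

from finetrainers import apply_context_parallel, attention_provider  # assumed location

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

pipe = ...  # load your pipeline / transformer as usual

# Register the split/gather hooks. For models supported in finetrainers the
# plan comes from the transformer metadata; custom models pass `plan=...`.
apply_context_parallel(pipe.transformer)

# Run inference under an attention provider that supports CP.
with attention_provider("flash"):
    output = pipe(prompt="...")
```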
Benchmarks
TODO: Will be updated in a future PR