
Commit b9e2563

eternalNight authored and amaurya committed
DeepCompile: Fuse allgather and downcast (deepspeedai#7588)
With autocast enabled, most weights are downcast before being used in calculations. Today zero3_compile gathers the FP32 weights before they are downcast. That is sub-optimal: FP32 weights consume more bandwidth to allgather and take more time to downcast. To reduce communication and downcast time, this PR fuses allgather and downcast in the dc ops. The target type is now passed to allgather_param() and prefetch_params_fused(), which downcast the (partial) weights before launching allgathers. This corresponds to issue 1 of deepspeedai#7577.

Tested with https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 (run with `deepspeed --num_gpus=N this_file.py -c -p -m 23` to collect torch and memory profiles, and with DINOV2_DEPTH = SIGLIP_DEPTH = 3, LLAMA2_DEPTH = 4 for faster compilation) on 5090 (which has limited inter-GPU bandwidth): time per step decreases from 438ms to 337ms and peak GPU memory usage from 9.5GB to 8.5GB.

Profiles of a single step before this PR (torch and memory profile screenshots):
https://github.com/user-attachments/assets/d9fe5296-7731-4542-924b-421ff7415054
https://github.com/user-attachments/assets/aa192802-8633-4e36-b2c4-f28b1b432663

After this PR:
https://github.com/user-attachments/assets/18a0e09c-155b-4783-adb5-b4d36c5c3691
https://github.com/user-attachments/assets/16a2ca74-8a89-4db9-9b68-81844295c61b

This PR also reduces peak memory usage because `fast_free_schedule()` today always arranges param allgathers and downcasts at the beginning of the graph. While the original FP32 params can be freed early, all FP16/BF16-casted params are kept in GPU memory at the beginning of the backward graph, leading to a higher peak in memory usage.

P.S. Probably due to organization branch rule settings, I can't find anywhere to allow reviewers to modify the branch, so I'll update the branch per reviewers' comments and rebase if needed.

Signed-off-by: Junjie Mao <[email protected]>
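To see why the fusion is numerically safe, here is a single-process sketch (illustration only, not code from this PR) in which concatenation stands in for the NCCL all-gather: casting each shard before "gathering" gives the same BF16 result as casting the fully gathered tensor, while moving half as many bytes in the FP32-to-BF16 case.

import torch

def allgather_then_downcast(shards, dtype):
    # Baseline: gather the FP32 shards, then downcast the full tensor.
    return torch.cat(shards).to(dtype)

def downcast_then_allgather(shards, dtype):
    # Fused behavior: downcast each shard first, then gather.
    # For FP32 -> BF16 this halves the bytes crossing the interconnect.
    return torch.cat([s.to(dtype) for s in shards])

shards = [torch.randn(1024) for _ in range(4)]
a = allgather_then_downcast(shards, torch.bfloat16)
b = downcast_then_allgather(shards, torch.bfloat16)
assert torch.equal(a, b)  # elementwise casts commute with concatenation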
1 parent 5be0b0c commit b9e2563

File tree

9 files changed: +156 -39 lines changed


csrc/compile/init.cpp

Lines changed: 4 additions & 2 deletions
@@ -10,8 +10,10 @@
 TORCH_LIBRARY(dc, m)
 {
-    m.def("allgather_param(Tensor a, int graph_id, int id) -> Tensor");
-    m.def("prefetch_params_fused(int graph_id, Tensor[] params, int[] ids) -> ()");
+    m.def("allgather_param(Tensor a, int graph_id, int id, ScalarType? dtype = None) -> Tensor");
+    m.def(
+        "prefetch_params_fused(int graph_id, Tensor[] params, int[] ids,"
+        " ScalarType[]? dtypes = None) -> ()");
     m.def("wait_allgather(Tensor(a) a, int graph_id, int id) -> Tensor(a)");
     m.def("release_param(Tensor(a) a, int graph_id, int id, int n_users) -> Tensor(a)");
     m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor");

csrc/compile/z3.cpp

Lines changed: 49 additions & 21 deletions
@@ -68,7 +68,12 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
                           c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
     {
         const DSParam& param = param_registry_->getParam(ds_id);
-        const at::Tensor& ds_tensor = param.getDSTensor();
+        at::Tensor ds_tensor = param.getDSTensor();
+
+        if (ds_tensor.scalar_type() != output_buf.scalar_type()) {
+            at::cuda::CUDAStreamGuard guard(ag_stream_);
+            ds_tensor = ds_tensor.to(output_buf.scalar_type(), true, true);
+        }

         if (symm_mem == nullptr) {
             // Fast path: assume uniform shard sizes (ZeRO-3 partitions are padded to uniform size)
@@ -110,6 +115,7 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
     }

     at::Tensor allgatherParam(long ds_id,
+                              std::optional<at::ScalarType> dtype,
                               c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
     {
         const DSParam& param = param_registry_->getParam(ds_id);
@@ -118,11 +124,16 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
         const int64_t true_numel = static_cast<int64_t>(productDim(param.getShape()));
         const int64_t padded_per_rank = (true_numel + world_size - 1) / world_size;
         const int64_t padded_numel = static_cast<int64_t>(world_size) * padded_per_rank;
+        at::ScalarType target_dtype = dtype ? dtype.value() : ds_tensor.scalar_type();

         if (param_registry_->isValid(ds_id)) {
             // Return a view sliced to the true size with the original shape
+            //
+            // Persistent params are gathered in their original dtype which may
+            // be different from the requested.
             auto base = param_registry_->getGatheredParam(ds_id);
             return base.flatten()
+                .to(target_dtype)
                 .index({torch::indexing::Slice(0, true_numel)})
                 .view(param.getShape());
         }
@@ -134,7 +145,7 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
         }
         if (!output_buf.defined()) {
             at::cuda::CUDAStreamGuard guard(ag_stream_);
-            output_buf = torch::empty({padded_numel}, ds_tensor.options());
+            output_buf = torch::empty({padded_numel}, ds_tensor.options().dtype(target_dtype));
         }

         assert(hasKey(ag_comp_done_events_, ds_id));
@@ -150,16 +161,20 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
             .view(param.getShape());
     }

-    void prefetchParamsFused(std::vector<int64_t> ds_ids,
+    void prefetchParamsFused(const std::vector<long>& ds_ids,
+                             const std::optional<std::vector<at::ScalarType>> dtypes,
                              c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
     {
-        std::vector<int64_t> invalid_ds_ids;
-        for (const auto& ds_id : ds_ids) {
-            if (!param_registry_->isValid(ds_id)) { invalid_ds_ids.push_back(ds_id); }
+        std::vector<std::tuple<long, std::optional<at::ScalarType>>> invalid_params;
+        for (int i = 0; i < ds_ids.size(); i++) {
+            if (!param_registry_->isValid(ds_ids[i])) {
+                auto dtype = dtypes ? dtypes.value()[i] : std::optional<at::ScalarType>();
+                invalid_params.push_back(std::make_tuple(ds_ids[i], dtype));
+            }
         }

         std::unordered_map<long, at::Tensor> output_bufs;
-        for (long ds_id : invalid_ds_ids) {
+        for (const auto& [ds_id, dtype] : invalid_params) {
             const DSParam& param = param_registry_->getParam(ds_id);
             const at::Tensor& ds_tensor = param.getDSTensor();
             const int world_size = process_group_->getSize();
@@ -173,22 +188,26 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
                     continue;
                 }
             }
-            output_bufs[ds_id] = torch::empty({padded_numel}, ds_tensor.options());
+            auto target_dtype = dtype ? dtype.value() : ds_tensor.scalar_type();
+            output_bufs[ds_id] =
+                torch::empty({padded_numel}, ds_tensor.options().dtype(target_dtype));
         }

-        for (long ds_id : invalid_ds_ids) {
+        for (const auto& [ds_id, _] : invalid_params) {
             ag_comp_done_events_[ds_id]->record();
             ag_comp_done_events_[ds_id]->block(ag_stream_);
         }

         ncclGroupStart();
-        for (long ds_id : invalid_ds_ids) {
+        for (const auto& [ds_id, _] : invalid_params) {
             assert(hasKey(output_bufs, ds_id));
             launchAllGather(output_bufs.at(ds_id), ds_id, symm_mem);
         }
         ncclGroupEnd();

-        for (long ds_id : invalid_ds_ids) { ag_comm_done_events_[ds_id]->record(ag_stream_); }
+        for (const auto& [ds_id, _] : invalid_params) {
+            ag_comm_done_events_[ds_id]->record(ag_stream_);
+        }
     }

     void releaseParam(long ds_id, long n_users)
@@ -458,12 +477,15 @@ void register_z3_param(long ds_id,
     }
 }

-at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
+at::Tensor allgather_param(at::Tensor param_tensor,
+                           long graph_id,
+                           long ds_id,
+                           std::optional<at::ScalarType> dtype)
 {
     auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);

     if (sync_before_allgather) { c10::cuda::device_synchronize(); }
-    auto ret = executor->allgatherParam(ds_id, symm_mem);
+    auto ret = executor->allgatherParam(ds_id, dtype, symm_mem);
     if (sync_after_allgather) { c10::cuda::device_synchronize(); }
     return ret;
 }
@@ -477,22 +499,25 @@ void set_persistent(long ds_id)
     for (auto& it : executors) {
         if (it.second->hasParam(ds_id)) {
             auto executor = getExecutor<Z3CustomOpExecutor>(it.first, executors);
-            executor->allgatherParam(ds_id, symm_mem);
+            auto dtype = param_registry->getParam(ds_id).getDtype();
+            executor->allgatherParam(ds_id, dtype, symm_mem);
         }
     }
 }

 void prefetch_params_fused(long graph_id,
-                           const std::vector<at::Tensor> params,
-                           const std::vector<long>& ds_ids)
+                           const std::vector<at::Tensor>& params,
+                           const std::vector<long>& ds_ids,
+                           const std::optional<std::vector<at::ScalarType>>& dtypes)
 {
     auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
-    executor->prefetchParamsFused(ds_ids, symm_mem);
+    executor->prefetchParamsFused(ds_ids, dtypes, symm_mem);
 }

 void prefetch_params_fused_meta(long graph_id,
-                                const std::vector<at::Tensor> params,
-                                const std::vector<long>& ds_ids)
+                                const std::vector<at::Tensor>& params,
+                                const std::vector<long>& ds_ids,
+                                const std::optional<std::vector<at::ScalarType>>& dtypes)
 {
 }
@@ -518,11 +543,14 @@ void clear_all_gathered_params()
     }
 }

-at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id)
+at::Tensor allgather_param_meta(at::Tensor param_tensor,
+                                long graph_id,
+                                long ds_id,
+                                std::optional<at::ScalarType> dtype)
 {
     const DSParam& param = param_registry->getParam(ds_id);
     auto options = param.getDSTensor().options().device(c10::kMeta);
-    at::Tensor output_buf = torch::empty(param.getShape(), options);
+    at::Tensor output_buf = torch::empty(param.getShape(), options.dtype(dtype));
     return output_buf;
 }
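A simplified Python rendering (assumptions: single process, helper name invented for illustration) of how allgatherParam now sizes its output: the true numel is padded to a multiple of world_size, and the gather buffer is allocated directly in the requested dtype so the downcast happens on the shard rather than on the gathered FP32 tensor.

import torch

def make_gather_buffer(shape, world_size, shard_dtype, target_dtype=None):
    # Pad the true numel so every rank contributes a shard of equal length,
    # and allocate the buffer in the target dtype (falling back to the shard's).
    true_numel = 1
    for d in shape:
        true_numel *= d
    padded_per_rank = (true_numel + world_size - 1) // world_size
    padded_numel = padded_per_rank * world_size
    dtype = target_dtype if target_dtype is not None else shard_dtype
    return torch.empty(padded_numel, dtype=dtype), true_numel

buf, true_numel = make_gather_buffer((1000, 13), world_size=8,
                                     shard_dtype=torch.float32,
                                     target_dtype=torch.bfloat16)
# After the all-gather, the first true_numel elements are viewed back to the shape.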

csrc/compile/z3.h

Lines changed: 14 additions & 6 deletions
@@ -21,18 +21,26 @@ void register_z3_param(long ds_id,
                        at::Tensor ds_tensor,
                        at::Tensor grad_buffer,
                        bool persistent);
-at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id);
+at::Tensor allgather_param(at::Tensor param_tensor,
+                           long graph_id,
+                           long ds_id,
+                           std::optional<at::ScalarType> dtype);
 void set_persistent(long ds_id);
 void prefetch_params_fused(long graph_id,
-                           const std::vector<at::Tensor> params,
-                           const std::vector<long>& ds_ids);
+                           const std::vector<at::Tensor>& params,
+                           const std::vector<long>& ds_ids,
+                           const std::optional<std::vector<at::ScalarType>>& dtypes);
 void prefetch_params_fused_meta(long graph_id,
-                                const std::vector<at::Tensor> params,
-                                const std::vector<long>& ds_ids);
+                                const std::vector<at::Tensor>& params,
+                                const std::vector<long>& ds_ids,
+                                const std::optional<std::vector<at::ScalarType>>& dtypes);
 // for profiling
 void invalidate_gathered_param(long ds_id);
 void clear_all_gathered_params();
-at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id);
+at::Tensor allgather_param_meta(at::Tensor param_tensor,
+                                long graph_id,
+                                long ds_id,
+                                std::optional<at::ScalarType> dtype);
 at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users);
 at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users);
 at::Tensor wait_allgather(at::Tensor v, long graph_id, const long ds_id);

csrc/includes/deepcompile.h

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,7 @@
 #include <c10/cuda/CUDAStream.h>
 #include <torch/csrc/cuda/nccl.h>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
+#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

 #if __has_include(<torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>)
@@ -261,6 +262,7 @@ class DSParam {
         : id_(id),
           shape_(std::move(ds_shape)),
           ds_tensor_(ds_tensor),
+          ds_dtype_(ds_tensor.scalar_type()),
           grad_buffer_(grad_buffer),
           partitioned_(partitioned),
           offset_(offset),
@@ -272,6 +274,7 @@ class DSParam {

     long getId() const { return id_; }
     std::vector<int64_t> getShape() const { return shape_; }
+    at::ScalarType getDtype() const { return ds_dtype_; }
     at::Tensor getDSTensor() const
     {
         // If the reload event exists and is complete, return the reloaded tensor (if defined)
@@ -343,6 +346,7 @@ class DSParam {
 private:
     long id_;
     std::vector<int64_t> shape_;
+    at::ScalarType ds_dtype_;
     at::Tensor ds_tensor_;
     at::Tensor ds_reload_tensor_;
     at::Tensor grad_buffer_;

deepspeed/compile/fx.py

Lines changed: 4 additions & 3 deletions
@@ -3,7 +3,7 @@

 # DeepSpeed Team

-from typing import Callable, Any, List
+from typing import Callable, Any, List, Dict
 from collections import defaultdict

 import torch
@@ -60,7 +60,8 @@ def add_args_process(graph: Graph,
 def add_postprocess(graph: Graph,
                     node: Node,
                     fn: Callable[..., Any],
-                    extra_args: List[int] = [],
+                    extra_args: List[Any] = [],
+                    extra_kwargs: Dict[str, Any] = {},
                     name=None,
                     meta={}) -> Node:
     # https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py
@@ -70,7 +71,7 @@ def add_postprocess(graph: Graph,
         args += (a, )

     node_users = node.users.keys()
-    new_node = graph.create_node('call_function', fn, args, {}, name=name)
+    new_node = graph.create_node('call_function', fn, args, extra_kwargs, name=name)
     users = {}
     for u in node_users:
         if u != new_node:
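The new extra_kwargs parameter simply forwards keyword arguments to graph.create_node; zero3_compile.py below uses it to attach the target dtype to the allgather op. A toy standalone FX example (module and op chosen for illustration only, not DeepSpeed code) of inserting a post-processing call_function node that carries a kwarg:

import torch
from torch import fx

class M(torch.nn.Module):
    def forward(self, x):
        return x.relu()

gm = fx.symbolic_trace(M())
g = gm.graph
relu = next(n for n in g.nodes if n.op == "call_method" and n.target == "relu")

# Create the post-processing node with a keyword argument, then rewire
# relu's other users to consume the new node instead.
with g.inserting_after(relu):
    cast = g.create_node("call_function", torch.ops.aten.to.dtype, (relu, ), {"dtype": torch.bfloat16})
for user in list(relu.users):
    if user is not cast:
        user.replace_input_with(relu, cast)
gm.recompile()

assert gm(torch.randn(4)).dtype == torch.bfloat16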

deepspeed/compile/passes/zero3_compile.py

Lines changed: 28 additions & 6 deletions
@@ -10,7 +10,7 @@
 import torch
 from torch.fx import Graph, Node, GraphModule

-from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses
+from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses, is_cast_op
 from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
 from ..profilers.graph_profile import ProfilingInterpreter
 from ..list_schedule import fast_free_schedule
@@ -21,14 +21,15 @@
 NAME = "zero3_compile"


-def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
+def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int, dtype: torch.dtype):
     new_ag_node = add_postprocess(graph,
                                   node,
                                   torch.ops.dc.allgather_param.default,
                                   extra_args=[graph_id, ds_id],
+                                  extra_kwargs={"dtype": dtype},
                                   name=f"allgather_ds_param_{node.target}_{ds_id}",
                                   meta=_make_node_meta(node, ds_id, True))
-    new_ag_node.meta["val"] = node.meta["val"]
+    new_ag_node.meta["val"] = node.meta["val"].to(dtype)

     # Set the previous node back to output
     # We don't want to change the output node to allgather
@@ -42,7 +43,7 @@ def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
                                     extra_args=[graph_id, ds_id],
                                     name=f"wait_allgather_ds_param__{node.target}_{ds_id}",
                                     meta=_make_node_meta(node, ds_id, False))
-    new_wait_node.meta["val"] = node.meta["val"]
+    new_wait_node.meta["val"] = new_ag_node.meta["val"]

     return new_ag_node

@@ -74,9 +75,30 @@ def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nod
         if len(pn.users) == 0:
             continue

-        add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name])
+        # If the only use of the parameter is a type-cast to a smaller type, fuse it with all-gather.
+        fuse_typecast = False
+        target_dtype = param_manager.params[pn.name].dtype
+        if len([user for user in pn.users if user.op != "output"]) == 1:
+            typecast_node = next(iter(pn.users))
+
+            is_cast, casted_dtype = is_cast_op(typecast_node)
+            if is_cast and casted_dtype.itemsize < target_dtype.itemsize:
+                fuse_typecast = True
+                target_dtype = casted_dtype
+
+        add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name], target_dtype)
+        if fuse_typecast:
+            users = node_to_uses[typecast_node]
+            wait_node = typecast_node.args[0]
+            for user in list(typecast_node.users.keys()):
+                if user.op == "output":
+                    wait_node.meta["original_output_name"] = typecast_node.name
+                user.replace_input_with(typecast_node, wait_node)
+            graph.erase_node(typecast_node)
+        else:
+            users = node_to_uses[pn]
+
         ds_id = param_manager.ds_ids[pn.name]
-        users = node_to_uses[pn]
         for user in users:
             # release_param() only accepts tensors as its first argument. If
             # `user` is a tuple, we should release the param after any of
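For reference, here is a loose standalone rendering of the fusion check (assumptions: the cast node exposes its target dtype as a `dtype` kwarg; in the real pass this is handled by is_cast_op() from deepspeed.compile.util):

import torch
from torch.fx import Node

def find_fusable_downcast(param_node: Node, param_dtype: torch.dtype):
    # Fuse only when the parameter's single non-output user is a cast to a
    # strictly smaller dtype; otherwise gather in the original dtype.
    users = [u for u in param_node.users if u.op != "output"]
    if len(users) != 1:
        return None, param_dtype
    cast = users[0]
    cast_dtype = cast.kwargs.get("dtype")  # simplification of is_cast_op()
    if cast_dtype is not None and cast_dtype.itemsize < param_dtype.itemsize:
        return cast, cast_dtype
    return None, param_dtype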

deepspeed/compile/profilers/graph_profile.py

Lines changed: 10 additions & 0 deletions
@@ -130,9 +130,15 @@ def run_node(self, n: torch.fx.Node) -> Any:
         assert isinstance(args, tuple)
         assert isinstance(kwargs, dict)

+        partitioned_params = {}
+
         def rebuild_param_if_necessary(v):
             if hasattr(v, "ds_id"):
                 v.all_gather(param_list=[v])
+                if hasattr(v, "ds_target_dtype"):
+                    casted = v.to(v.ds_target_dtype)
+                    partitioned_params[id(casted)] = v
+                    return casted
             return v

         args = map_aggregate(args, lambda x: rebuild_param_if_necessary(x))
@@ -191,6 +197,8 @@ def rebuild_param_if_necessary(v):
         tensor_size = _node_size(out)

         def partition_param_if_necessary(v):
+            if id(v) in partitioned_params:
+                v = partitioned_params[id(v)]
             if hasattr(v, "ds_id") and not v.ds_persist:
                 v.partition(param_list=[v], has_been_updated=False)
             return v
@@ -227,6 +235,8 @@ def partition_param_if_necessary(v):
             assert hasattr(out, "ds_id")
             if not out.ds_persist:
                 self.nz3.invalidate_gathered_param(args[2])
+            if "dtype" in n.kwargs:
+                setattr(out, "ds_target_dtype", n.kwargs["dtype"])
             self.allgather_mem[out.ds_id] = n.meta["alloc_mem"]

         return out
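The profiler change is a small bookkeeping pattern: when a gathered ZeRO-3 param carries a ds_target_dtype, the node runs on a casted copy, and an id()-keyed map lets the original param be re-partitioned afterwards. A minimal standalone sketch of that pattern (invented helper names, not the actual ProfilingInterpreter code):

import torch

partitioned_params = {}

def swap_in(param, target_dtype):
    # Hand the node a casted copy, remembering which original it came from.
    casted = param.to(target_dtype)
    partitioned_params[id(casted)] = param
    return casted

def swap_out(value):
    # Map a casted copy back to the original param so it can be re-partitioned.
    return partitioned_params.get(id(value), value)

p = torch.randn(8, dtype=torch.float32)
c = swap_in(p, torch.bfloat16)
assert swap_out(c) is p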

0 commit comments
