Skip to content

Commit 01d13e2

Browse files
jiashuyshijieliu
and authored
Fix issue related to empty batch (#271)
* Fix issue related to empty batch * Empty batch in prefetch * Empty batch in backward * debug * Fix IMA caused by not configuring shared memory size * Remove comments in update * fix return value of DynamicEmbeddingTable.update when empty batch * fix as AI suggests, AI is useful and helpful * avoid tensor.zeros --------- Co-authored-by: aleliu <aleliu@nvidia.com>
1 parent 38224a0 commit 01d13e2

File tree

5 files changed

+204
-43
lines changed

5 files changed

+204
-43
lines changed

corelib/dynamicemb/dynamicemb/key_value_table.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,9 @@ def update(
546546

547547
batch = keys.size(0)
548548

549+
if batch == 0:
550+
return 0, None, None
551+
549552
device = keys.device
550553
founds = torch.empty(batch, dtype=torch.bool, device=device)
551554
pointers = torch.empty(batch, dtype=torch.long, device=device)
@@ -1184,6 +1187,16 @@ def find_impl(
11841187

11851188
scores = self.create_scores(batch, device, input_scores)
11861189

1190+
if batch == 0:
1191+
return (
1192+
0,
1193+
torch.empty_like(unique_keys),
1194+
torch.empty(batch, dtype=torch.long, device=device),
1195+
torch.empty(batch, dtype=torch.uint64, device=device)
1196+
if scores is not None
1197+
else None,
1198+
)
1199+
11871200
score_args_lookup = [
11881201
ScoreArg(
11891202
name=self.score_policy.name,
@@ -1354,6 +1367,9 @@ def update(
13541367

13551368
batch = keys.size(0)
13561369

1370+
if batch == 0:
1371+
return 0, None, None
1372+
13571373
device = keys.device
13581374
founds = torch.empty(batch, dtype=torch.bool, device=device)
13591375
indices = torch.empty(batch, dtype=self.key_index_map.index_type, device=device)

corelib/dynamicemb/src/dynamic_emb_op.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,10 @@ void load_from_combined_table(std::optional<at::Tensor> dev_table,
898898
std::optional<at::Tensor> uvm_table,
899899
at::Tensor indices, at::Tensor output) {
900900

901+
int64_t num_total = indices.size(0);
902+
if (num_total == 0) {
903+
return;
904+
}
901905
int64_t stride = -1;
902906
int64_t dim = output.size(1);
903907
if ((not dev_table.has_value()) and (not uvm_table.has_value())) {
@@ -934,8 +938,6 @@ void load_from_combined_table(std::optional<at::Tensor> dev_table,
934938
auto val_type = get_data_type(output);
935939
auto index_type = get_data_type(indices);
936940

937-
int64_t num_total = indices.size(0);
938-
939941
constexpr int kWarpSize = 32;
940942
constexpr int MULTIPLIER = 4;
941943
constexpr int BLOCK_SIZE_VEC = 64;

corelib/dynamicemb/src/index_calculation.cu

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,15 @@ void select(at::Tensor flags, at::Tensor inputs, at::Tensor outputs,
523523
auto num_select_iter_type =
524524
scalartype_to_datatype(num_selected.dtype().toScalarType());
525525

526+
if (num_total == 0) {
527+
DISPATCH_INTEGER_DATATYPE_FUNCTION(
528+
num_select_iter_type, NumSelectedIteratorT, [&] {
529+
DEMB_CUDA_CHECK(cudaMemsetAsync(
530+
reinterpret_cast<NumSelectedIteratorT *>(num_selected.data_ptr()), 0,
531+
sizeof(NumSelectedIteratorT), stream));
532+
});
533+
return;
534+
}
526535
DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, KeyType, [&] {
527536
DISPATCH_INTEGER_DATATYPE_FUNCTION(
528537
num_select_iter_type, NumSelectedIteratorT, [&] {
@@ -545,6 +554,15 @@ void select_index(at::Tensor flags, at::Tensor output_indices,
545554
auto num_select_iter_type =
546555
scalartype_to_datatype(num_selected.dtype().toScalarType());
547556

557+
if (num_total == 0) {
558+
DISPATCH_INTEGER_DATATYPE_FUNCTION(
559+
num_select_iter_type, NumSelectedIteratorT, [&] {
560+
DEMB_CUDA_CHECK(cudaMemsetAsync(
561+
reinterpret_cast<NumSelectedIteratorT *>(num_selected.data_ptr()), 0,
562+
sizeof(NumSelectedIteratorT), stream));
563+
});
564+
return;
565+
}
548566
DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, KeyType, [&] {
549567
DISPATCH_INTEGER_DATATYPE_FUNCTION(
550568
num_select_iter_type, NumSelectedIteratorT, [&] {

corelib/dynamicemb/src/optimizer.cu

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ All rights reserved. # SPDX-License-Identifier: Apache-2.0
2020
#include "optimizer_kernel.cuh"
2121
#include "torch_utils.h"
2222
#include "utils.h"
23+
#include <functional>
2324

2425
void find_pointers(std::shared_ptr<dyn_emb::DynamicVariableBase> table,
2526
const size_t n, const at::Tensor keys, at::Tensor values,
@@ -545,7 +546,8 @@ void launch_update_kernel_for_combined_table(
545546
GradType *grads, WeightType *dev_table, WeightType *uvm_table,
546547
IndexType *indices, OptimizerType opt, int64_t const ev_nums,
547548
uint32_t const dim, int64_t const stride, int64_t const split_index,
548-
int device_id) {
549+
int device_id,
550+
std::function<float(int)> smem_size_f = [](int block_size) { return 0; }) {
549551
auto stream = at::cuda::getCurrentCUDAStream().stream();
550552
auto &device_prop = DeviceProp::getDeviceProp(device_id);
551553
if (dim % 4 == 0) {
@@ -574,7 +576,7 @@ void launch_update_kernel_for_combined_table(
574576

575577
auto kernel = update_with_index_kernel<GradType, WeightType, IndexType,
576578
OptimizerType>;
577-
kernel<<<grid_size, block_size, 0, stream>>>(
579+
kernel<<<grid_size, block_size, smem_size_f(block_size), stream>>>(
578580
ev_nums, dim, stride, split_index, grads, dev_table, uvm_table, indices,
579581
nullptr, opt);
580582
}
@@ -797,7 +799,8 @@ void rowwise_adagrad_for_combined_table(at::Tensor grads, at::Tensor indices,
797799

798800
launch_update_kernel_for_combined_table<g_t, w_t, i_t, decltype(opt)>(
799801
grad_ptr, dev_ptr, uvm_ptr, index_ptr, opt, ev_nums, dim, stride,
800-
split_index, device_id);
802+
split_index, device_id,
803+
[](int block_size) { return block_size * sizeof(float); });
801804
});
802805
});
803806
});

corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py

Lines changed: 160 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def test_forward_train_eval(opt_type, opt_params, caching, deterministic, PS):
404404

405405
"""
406406
For torchrec's adam optimizer, it will increment the optimizer_step in every forward,
407-
which will affect the weights update, pay attention to it or try to use `set_optimizer_step()`
407+
which will affect the weights update, pay attention to it or try to use `set_optimizer_step()`
408408
to control(not verified) it.
409409
"""
410410

@@ -444,6 +444,7 @@ def test_forward_train_eval(opt_type, opt_params, caching, deterministic, PS):
444444
[
445445
(True, DynamicEmbPoolingMode.NONE, [8, 8, 8]),
446446
(False, DynamicEmbPoolingMode.NONE, [16, 16, 16]),
447+
(False, DynamicEmbPoolingMode.NONE, [17, 17, 17]),
447448
(False, DynamicEmbPoolingMode.SUM, [128, 32, 16]),
448449
(False, DynamicEmbPoolingMode.MEAN, [4, 8, 16]),
449450
],
@@ -467,7 +468,10 @@ def test_backward(opt_type, opt_params, caching, pooling_mode, dims, determinist
467468
max_capacity = 2048
468469

469470
dyn_emb_table_options_list = []
471+
cmp_with_torchrec = True
470472
for dim in dims:
473+
if dim % 4 != 0:
474+
cmp_with_torchrec = False
471475
dyn_emb_table_options = DynamicEmbTableOptions(
472476
dim=dim,
473477
init_capacity=max_capacity,
@@ -492,49 +496,68 @@ def test_backward(opt_type, opt_params, caching, pooling_mode, dims, determinist
492496
**opt_params,
493497
)
494498
num_embs = [max_capacity // 2 for d in dims]
495-
stbe = create_split_table_batched_embedding(
496-
table_names,
497-
feature_table_map,
498-
OPTIM_TYPE[opt_type],
499-
opt_params,
500-
dims,
501-
num_embs,
502-
POOLING_MODE[pooling_mode],
503-
device,
504-
)
505-
init_embedding_tables(stbe, bdeb)
506-
"""
507-
feature number = 4, batch size = 2
508499

509-
f0 [0,1], [12],
510-
f1 [64,8], [12],
511-
f2 [15, 2, 7], [105],
512-
f3 [], [0]
513-
"""
514-
for i in range(10):
515-
indices = torch.tensor(
516-
[0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
517-
).to(key_type)
518-
offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
519-
key_type
500+
if cmp_with_torchrec:
501+
stbe = create_split_table_batched_embedding(
502+
table_names,
503+
feature_table_map,
504+
OPTIM_TYPE[opt_type],
505+
opt_params,
506+
dims,
507+
num_embs,
508+
POOLING_MODE[pooling_mode],
509+
device,
520510
)
511+
init_embedding_tables(stbe, bdeb)
512+
"""
513+
feature number = 4, batch size = 2
514+
515+
f0 [0,1], [12],
516+
f1 [64,8], [12],
517+
f2 [15, 2, 7], [105],
518+
f3 [], [0]
519+
"""
520+
for i in range(10):
521+
indices = torch.tensor(
522+
[0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
523+
).to(key_type)
524+
offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
525+
key_type
526+
)
521527

522-
embs_bdeb = bdeb(indices, offsets)
523-
embs_stbe = stbe(indices, offsets)
524-
525-
torch.cuda.synchronize()
526-
with torch.no_grad():
527-
torch.testing.assert_close(embs_bdeb, embs_stbe, rtol=1e-06, atol=1e-06)
528+
embs_bdeb = bdeb(indices, offsets)
529+
embs_stbe = stbe(indices, offsets)
530+
531+
torch.cuda.synchronize()
532+
with torch.no_grad():
533+
torch.testing.assert_close(embs_bdeb, embs_stbe, rtol=1e-06, atol=1e-06)
534+
535+
loss = embs_bdeb.mean()
536+
loss.backward()
537+
loss_stbe = embs_stbe.mean()
538+
loss_stbe.backward()
539+
540+
torch.cuda.synchronize()
541+
torch.testing.assert_close(loss, loss_stbe)
542+
543+
print(f"Passed iteration {i}")
544+
else:
545+
# This scenario will not test correctness, but rather test whether it functions correctly.
546+
for i in range(10):
547+
indices = torch.tensor(
548+
[0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
549+
).to(key_type)
550+
offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
551+
key_type
552+
)
528553

529-
loss = embs_bdeb.mean()
530-
loss.backward()
531-
loss_stbe = embs_stbe.mean()
532-
loss_stbe.backward()
554+
embs_bdeb = bdeb(indices, offsets)
555+
loss = embs_bdeb.mean()
556+
loss.backward()
533557

534-
torch.cuda.synchronize()
535-
torch.testing.assert_close(loss, loss_stbe)
558+
torch.cuda.synchronize()
536559

537-
print(f"Passed iteration {i}")
560+
print(f"Passed iteration {i}")
538561

539562
if deterministic:
540563
del os.environ["DEMB_DETERMINISM_MODE"]
@@ -853,3 +876,102 @@ def test_deterministic_insert(opt_type, opt_params, caching, PS, iteration, batc
853876

854877
del os.environ["DEMB_DETERMINISM_MODE"]
855878
print("all check passed")
879+
880+
881+
@pytest.mark.parametrize(
882+
"opt_type,opt_params",
883+
[
884+
(EmbOptimType.SGD, {"learning_rate": 0.3}),
885+
(
886+
EmbOptimType.EXACT_ROWWISE_ADAGRAD,
887+
{
888+
"learning_rate": 0.3,
889+
"eps": 3e-5,
890+
},
891+
),
892+
],
893+
)
894+
@pytest.mark.parametrize("dim", [7, 8])
895+
@pytest.mark.parametrize("caching", [True, False])
896+
@pytest.mark.parametrize("deterministic", [True, False])
897+
@pytest.mark.parametrize("PS", [None])
898+
def test_empty_batch(opt_type, opt_params, dim, caching, deterministic, PS):
899+
print(
900+
f"step in test_forward_train_eval_empty_batch , opt_type = {opt_type} opt_params = {opt_params}"
901+
)
902+
903+
if deterministic:
904+
os.environ["DEMB_DETERMINISM_MODE"] = "ON"
905+
906+
assert torch.cuda.is_available()
907+
device_id = 0
908+
device = torch.device(f"cuda:{device_id}")
909+
910+
dims = [dim, dim, dim]
911+
table_names = ["table0", "table1", "table2"]
912+
key_type = torch.int64
913+
value_type = torch.float32
914+
915+
init_capacity = 1024
916+
max_capacity = 2048
917+
918+
dyn_emb_table_options_list = []
919+
for dim in dims:
920+
dyn_emb_table_options = DynamicEmbTableOptions(
921+
dim=dim,
922+
init_capacity=init_capacity,
923+
max_capacity=max_capacity,
924+
index_type=key_type,
925+
embedding_dtype=value_type,
926+
device_id=device_id,
927+
score_strategy=DynamicEmbScoreStrategy.TIMESTAMP,
928+
caching=caching,
929+
local_hbm_for_values=1024**3,
930+
external_storage=PS,
931+
)
932+
dyn_emb_table_options_list.append(dyn_emb_table_options)
933+
934+
bdebt = BatchedDynamicEmbeddingTablesV2(
935+
table_names=table_names,
936+
table_options=dyn_emb_table_options_list,
937+
feature_table_map=[0, 0, 1, 2],
938+
pooling_mode=DynamicEmbPoolingMode.NONE,
939+
optimizer=opt_type,
940+
use_index_dedup=True,
941+
**opt_params,
942+
)
943+
bdebt.enable_prefetch = True
944+
"""
945+
feature number = 4, batch size = 1
946+
947+
f0 [],
948+
f1 [],
949+
f2 [],
950+
f3 [],
951+
"""
952+
indices = torch.tensor([], dtype=key_type, device=device)
953+
offsets = torch.tensor([0, 0, 0, 0, 0], dtype=key_type, device=device)
954+
955+
pretch_stream = torch.cuda.Stream()
956+
forward_stream = torch.cuda.Stream()
957+
958+
if caching:
959+
with torch.cuda.stream(pretch_stream):
960+
bdebt.prefetch(indices, offsets, forward_stream)
961+
torch.cuda.synchronize()
962+
963+
with torch.cuda.stream(forward_stream):
964+
res = bdebt(indices, offsets)
965+
torch.cuda.synchronize()
966+
967+
res.mean().backward()
968+
969+
with torch.no_grad():
970+
bdebt.eval()
971+
bdebt(indices, offsets)
972+
torch.cuda.synchronize()
973+
974+
if deterministic:
975+
del os.environ["DEMB_DETERMINISM_MODE"]
976+
977+
print("all check passed")

0 commit comments

Comments
 (0)