Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ struct EmbeddingWithScaledGradientGradCUDAFunctor {
cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
#endif

// When input has 0 elements, d_table is already correctly zeroed.
// Skip all kernel launches to avoid CUDA error(9) from GET_BLOCKS(0)==0.
if (K == 0) return;

if (FLAGS_embedding_deterministic == 1) {
funcs::LaunchEmbeddingGradDeterministicKernel<T, IdT>(
dev_ctx_, ids, d_output, d_table, N, D, K);
Expand Down
85 changes: 85 additions & 0 deletions test/legacy_test/test_embedding_scale_grad_by_freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,5 +137,90 @@ def test_argument_error(self):
)


class TestEmbeddingScaleGradByFreqZeroSize(unittest.TestCase):
    """Test scale_grad_by_freq=True with 0-size input tensors.

    When input has 0 elements (e.g. shape [0, 3]), the backward kernel
    used to launch CountFreqKernel with GET_BLOCKS(0)=0 grid blocks,
    causing CUDA error(9). This test ensures the fix works correctly.
    """

    def setUp(self):
        # Run on every available device (CPU and, when built with CUDA, GPU).
        self.places = get_places()
        # Fixed embedding table: 10 rows x 4 embedding dims, float32.
        self.weight_np = np.random.random((10, 4)).astype("float32")

    def _check_zero_size_dygraph(self, x_shape, x_dtype, padding_idx):
        """Run forward + backward in dygraph mode with a 0-size input.

        Args:
            x_shape (list[int]): input id-tensor shape containing a 0 dim.
            x_dtype (str): id dtype, 'int32' or 'int64'.
            padding_idx (int): padding index forwarded to ``embedding``.

        Asserts that the output shape is ``x_shape + [embed_dim]`` and
        that the weight gradient is an all-zero tensor with the weight's
        shape (no ids were looked up, so nothing may be accumulated).
        """
        for place in self.places:
            paddle.disable_static(place)
            # try/finally so a failing assertion cannot leave the process
            # stuck in dygraph mode and contaminate later tests.
            try:
                x = paddle.zeros(x_shape, dtype=x_dtype)
                w = paddle.to_tensor(self.weight_np)
                w.stop_gradient = False

                out = embedding(
                    x, w, padding_idx=padding_idx, scale_grad_by_freq=True
                )
                # Output shape: x_shape + [embed_dim]
                expected_out_shape = [*x_shape, self.weight_np.shape[1]]
                self.assertEqual(list(out.shape), expected_out_shape)

                out.backward()
                # Weight grad must have the same shape as weight and be
                # all-zeros.
                self.assertIsNotNone(w.grad)
                self.assertEqual(
                    list(w.grad.shape), list(self.weight_np.shape)
                )
                np.testing.assert_array_equal(
                    w.grad.numpy(), np.zeros_like(self.weight_np)
                )
            finally:
                paddle.enable_static()

    def test_zero_first_dim_int32(self):
        # shape [0, 3] int32, padding_idx=5
        self._check_zero_size_dygraph([0, 3], 'int32', 5)

    def test_zero_first_dim_int64(self):
        # shape [0, 3] int64, padding_idx=-1
        self._check_zero_size_dygraph([0, 3], 'int64', -1)

    def test_zero_first_dim_only(self):
        # shape [0] int64
        self._check_zero_size_dygraph([0], 'int64', 2)

    def test_zero_second_dim_int64(self):
        # shape [2, 0] int64, padding_idx=2
        self._check_zero_size_dygraph([2, 0], 'int64', 2)

    def test_zero_second_dim_int32(self):
        # shape [6, 0] int32, padding_idx=5
        self._check_zero_size_dygraph([6, 0], 'int32', 5)

    def test_zero_size_static(self):
        """Verify 0-size input works in static graph mode too."""
        paddle.enable_static()
        x_shape = [0, 3]
        x_dtype = 'int64'
        for place in self.places:
            # Fresh main/startup programs per place so graphs don't leak
            # between devices.
            with paddle.static.program_guard(
                paddle.static.Program(), paddle.static.Program()
            ):
                x = paddle.static.data("x", x_shape, x_dtype)
                w = paddle.static.data(
                    "w", self.weight_np.shape, self.weight_np.dtype
                )
                w.stop_gradient = False
                out = embedding(x, w, padding_idx=-1, scale_grad_by_freq=True)
                w_grad = paddle.static.gradients([out], w)
                exe = paddle.static.Executor(place)
                x_val = np.zeros(x_shape, dtype=x_dtype)
                [out_val, grad_val] = exe.run(
                    feed={"x": x_val, "w": self.weight_np},
                    fetch_list=[out, w_grad],
                    return_numpy=True,
                )
                expected_out_shape = [*x_shape, self.weight_np.shape[1]]
                self.assertEqual(list(out_val.shape), expected_out_shape)
                # 0-size lookup must yield an all-zero weight gradient.
                np.testing.assert_array_equal(
                    grad_val, np.zeros_like(self.weight_np)
                )


# Allow running this test file directly (outside the CI test runner).
if __name__ == '__main__':
    unittest.main()
Loading