Skip to content

Commit e0a071d

Browse files
authored
[Embedding] Enable EmbeddingVariable python UT in cibuild. (#643)
1 parent 9acef18 commit e0a071d

File tree

4 files changed

+150
-138
lines changed

4 files changed

+150
-138
lines changed

cibuild/gpu-ut/gpu-python-ut.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,6 @@ export TF_BUILD_BAZEL_TARGET="$TF_ALL_TARGETS "\
107107
"-//tensorflow/python/keras:convolutional_test "\
108108
"-//tensorflow/python/keras:lstm_v2_test "\
109109
"-//tensorflow/python/keras:lstm_v2_test_gpu "\
110-
"-//tensorflow/python:embedding_variable_ops_gpu_test "\
111-
"-//tensorflow/python:embedding_variable_ops_gpu_test_gpu "\
112110
"-//tensorflow/python/kernel_tests:normalize_op_test "\
113111
"-//tensorflow/python/kernel_tests:svd_op_test "
114112

tensorflow/core/kernels/kv_variable_ops.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ class KvVariableShapeOp : public OpKernel {
9595
Name("KvVariableShape").Device(DEVICE_##dev) \
9696
.TypeConstraint<type>("out_type") \
9797
.TypeConstraint<ktype>("Tkeys") \
98-
.TypeConstraint<vtype>("dtype"), \
98+
.TypeConstraint<vtype>("dtype") \
99+
.HostMemory("output"), \
99100
KvVariableShapeOp<type, ktype, vtype>);
100101
#define REGISTER_KERNELS_ALL(dev, type) \
101102
REGISTER_KERNELS(dev, int32, int32, type) \

0 commit comments

Comments
 (0)