Skip to content

Commit c77b197

Browse files
authored
Merge fixes from directml branch (#124)
Merges some of the recent changes from the directml branch:
* Use compute queue for AMD devices (#102)
* Register List Kernels for DML (#95)
* Update DirectMLX to latest (#104)
* Remove extra rows from test email (#106)
* Fix DML's Select kernel for int64 (#113)
* Fix list kernels and tensor array ops registration (#114)
* Simplify CI scripts (#112)
* Fix StridedSlice's input size coalescing (#115)
* Disable int64 image test (#116)
* Fix network share copy path (#117)
* Pipeline should continue if a test job fails (#118)
* Switch network share path to use build number instead of build ID
* Add missing HostMemory int32 registrations for _Arg and _RetVal (#122)
* Implement all the arithmetic Scatter and ResourceScatter operators (#121)
* Register emulated kernel implementations for RandomStandardNormal and TruncatedNormal (#120)
1 parent 4d079e9 commit c77b197

File tree

93 files changed

+2393
-1875
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

93 files changed

+2393
-1875
lines changed

tensorflow/core/common_runtime/dml/dml_device_state.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "dml_upload_heap.h"
2525
#include "dml_util.h"
2626
#include "tensorflow/core/platform/env.h"
27+
#include "tensorflow/core/util/env_var.h"
2728
#include "tensorflow/stream_executor/platform/default/dso_loader.h"
2829

2930
using Microsoft::WRL::ComPtr;
@@ -70,8 +71,20 @@ namespace tensorflow {
7071
ComPtr<IDMLDevice> dml_device;
7172
dml_device = CreateDmlDevice(d3d_device.Get(), dml_flags);
7273

74+
// Default to using compute queues for AMD since it seems to mitigate TDRs and
75+
// improve performance
76+
const bool use_compute_queue_default = adapter.VendorID() == VendorID::kAmd;
77+
78+
bool use_compute_queue;
79+
Status s = ReadBoolFromEnvVar("TF_DIRECTML_USE_COMPUTE_QUEUE",
80+
use_compute_queue_default, &use_compute_queue);
81+
82+
D3D12_COMMAND_LIST_TYPE queue_type = use_compute_queue
83+
? D3D12_COMMAND_LIST_TYPE_COMPUTE
84+
: D3D12_COMMAND_LIST_TYPE_DIRECT;
85+
7386
D3D12_COMMAND_QUEUE_DESC command_queue_desc = {};
74-
command_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
87+
command_queue_desc.Type = queue_type;
7588
command_queue_desc.Priority = D3D12_COMMAND_QUEUE_PRIORITY_NORMAL;
7689
command_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
7790
command_queue_desc.NodeMask = 0;

tensorflow/core/common_runtime/gpu/gpu_process_state.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
2727
#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
2828
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
29+
#include "tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h"
2930
#include "tensorflow/core/common_runtime/pool_allocator.h"
3031
#include "tensorflow/core/common_runtime/shared_counter.h"
3132
#include "tensorflow/core/framework/allocator.h"

tensorflow/core/kernels/BUILD

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2727,10 +2727,17 @@ tf_kernel_library(
27272727
tf_kernel_library(
27282728
name = "list_kernels",
27292729
srcs = ["list_kernels.cc"],
2730-
hdrs = ["list_kernels.h"],
2730+
hdrs = [
2731+
"list_kernels.h",
2732+
"tensor_array.h",
2733+
"aggregate_ops.h",
2734+
"split_lib.h"] + if_dml(["dml_tensor_array.h"]),
27312735
gpu_srcs = [
27322736
"list_kernels.cu.cc",
27332737
"list_kernels.h",
2738+
"tensor_array.h",
2739+
"aggregate_ops.h",
2740+
"split_lib.h",
27342741
],
27352742
deps = [
27362743
":concat_lib",
@@ -8094,6 +8101,7 @@ tf_kernel_library(
80948101
"dml_gather_op.cc",
80958102
"dml_gather_nd_op.cc",
80968103
"dml_scatter_nd_op.cc",
8104+
"dml_scatter_update_ops.cc",
80978105
"dml_tensor_scatter_ops.cc",
80988106
"dml_scan_ops.cc",
80998107
"dml_dynamic_stitch_op.cc",
@@ -8145,6 +8153,7 @@ tf_kernel_library(
81458153
"dml_kernel_wrapper.h",
81468154
"dml_ops_common.h",
81478155
"assign_op.h",
8156+
"random_op.h",
81488157
"stateless_random_ops.h",
81498158
"tensor_array.h",
81508159
"concat_lib.h",

tensorflow/core/kernels/dml_addn_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class DmlAddNKernel : public DmlKernel {
6666
&identity_desc};
6767
Initialize(ctx, std::move(tensors), op_desc);
6868
} else {
69-
auto scope = dml::Scope(ctx->GetDmlDevice());
69+
auto scope = dml::Graph(ctx->GetDmlDevice());
7070
auto result = dml::InputTensor(scope, 0, inputs[0]);
7171

7272
for (uint32_t i = 1; i < inputs.size(); ++i) {

tensorflow/core/kernels/dml_batch_norm_ops.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ class DmlFusedBatchNormKernel : public DmlKernel {
321321
// the mean/variance tensors back to TF.
322322

323323
auto scope =
324-
dml::Scope(ctx->GetDmlDevice(), GetDmlXTensorPolicy(tensor_format));
324+
dml::Graph(ctx->GetDmlDevice(), GetDmlXTensorPolicy(tensor_format));
325325
auto x = dml::InputTensor(scope, 0, input_descs[0]);
326326
auto scale = dml::InputTensor(scope, 1, input_descs[1]);
327327
auto offset = dml::InputTensor(scope, 2, input_descs[2]);
@@ -451,7 +451,7 @@ class DmlFusedBatchNormKernel : public DmlKernel {
451451
auto output_descs = GetDmlTensorDescs(tensors.outputs);
452452

453453
auto scope =
454-
dml::Scope(ctx->GetDmlDevice(), GetDmlXTensorPolicy(tensor_format));
454+
dml::Graph(ctx->GetDmlDevice(), GetDmlXTensorPolicy(tensor_format));
455455
auto x = dml::InputTensor(scope, 0, input_descs[0]);
456456
auto mean = dml::InputTensor(scope, 1, input_descs[1]);
457457
auto variance = dml::InputTensor(scope, 2, input_descs[2]);
@@ -574,7 +574,7 @@ class DmlBatchNormWithGlobalNormalizationKernel : public DmlKernel {
574574
auto output_descs = GetDmlTensorDescs(tensors.outputs);
575575

576576
const uint32_t beta_index = scale_after_normalization ? 4 : 3;
577-
auto scope = dml::Scope(ctx->GetDmlDevice());
577+
auto scope = dml::Graph(ctx->GetDmlDevice());
578578
auto t = dml::InputTensor(scope, 0, input_descs[0]);
579579
auto m = dml::InputTensor(scope, 1, input_descs[1]);
580580
auto v = dml::InputTensor(scope, 2, input_descs[2]);
@@ -684,7 +684,7 @@ class DmlFusedBatchNormGradKernel : public DmlKernel {
684684
auto output_descs = GetDmlTensorDescs(tensors.outputs);
685685

686686
auto scope =
687-
dml::Scope(ctx->GetDmlDevice(), GetDmlXTensorPolicy(tensor_format));
687+
dml::Graph(ctx->GetDmlDevice(), GetDmlXTensorPolicy(tensor_format));
688688

689689
auto y_backprop =
690690
dml::InputTensor(scope, kYBackprop, input_descs[kYBackprop]);
@@ -885,7 +885,7 @@ class DmlBatchGlobalNormGradKernel : public DmlKernel {
885885
auto input_descs = GetDmlTensorDescs(tensors.inputs);
886886
auto output_descs = GetDmlTensorDescs(tensors.outputs);
887887

888-
auto scope = dml::Scope(ctx->GetDmlDevice());
888+
auto scope = dml::Graph(ctx->GetDmlDevice());
889889

890890
const uint32_t back_prop_index =
891891
scale_after_normalization ? kBackProp : kBackProp - 1;

tensorflow/core/kernels/dml_batch_to_space_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ class DmlBatchToSpaceKernel : public DmlKernel {
310310
return;
311311
}
312312

313-
auto scope = dml::Scope(ctx->GetDmlDevice());
313+
auto scope = dml::Graph(ctx->GetDmlDevice());
314314
auto input = dml::InputTensor(scope, 0, inputs[0]);
315315

316316
absl::Span<const int64> internal_block_sizes =
@@ -393,7 +393,7 @@ class DmlBatchToSpaceKernel : public DmlKernel {
393393
// Finally, slice the appropriate dimensions
394394
dml::TensorDesc::Dimensions slice_offsets(perm_reshaped_sizes.size());
395395
dml::TensorDesc::Dimensions slice_sizes = perm_reshaped_sizes;
396-
dml::TensorDesc::Dimensions slice_strides(perm_reshaped_sizes.size(), 1);
396+
absl::InlinedVector<int32_t, 4> slice_strides(perm_reshaped_sizes.size(), 1);
397397

398398
absl::Span<const int64> internal_crops = init_helper->GetInternalCrops();
399399

tensorflow/core/kernels/dml_check_numerics_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class DmlCheckNumericsKernel : public DmlKernel {
7373
tensors.outputs = {output};
7474

7575
auto inputs = GetDmlTensorDescs(tensors.inputs);
76-
auto scope = dml::Scope(ctx->GetDmlDevice());
76+
auto scope = dml::Graph(ctx->GetDmlDevice());
7777
auto input_tensor = dml::InputTensor(scope, 0, inputs[0]);
7878

7979
// Reduce doesn't support less than 32bit integer datatypes, so we need to

tensorflow/core/kernels/dml_conv_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ class DmlFusedConv2DKernel : public DmlKernel {
539539
auto input_descs = GetDmlTensorDescs(tensors.inputs);
540540
auto output_descs = GetDmlTensorDescs(tensors.outputs);
541541

542-
auto scope = dml::Scope(ctx->GetDmlDevice(), GetDmlXTensorPolicy(conv_params.data_format));
542+
auto scope = dml::Graph(ctx->GetDmlDevice(), GetDmlXTensorPolicy(conv_params.data_format));
543543
auto input = dml::InputTensor(scope, 0, input_descs[0]);
544544
auto filter = dml::InputTensor(scope, 1, input_descs[1]);
545545

tensorflow/core/kernels/dml_cwise_ops.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ class DmlCompositeBinaryKernel : public DmlKernel {
184184
auto inputs = GetDmlTensorDescs(tensors.inputs);
185185
auto outputs = GetDmlTensorDescs(tensors.outputs);
186186

187-
auto scope = dml::Scope(ctx->GetDmlDevice());
187+
auto scope = dml::Graph(ctx->GetDmlDevice());
188188
auto x = dml::InputTensor(scope, 0, inputs[0]);
189189
auto y = dml::InputTensor(scope, 1, inputs[1]);
190190

@@ -319,7 +319,7 @@ class DmlCompositeUnaryKernel : public DmlKernel {
319319
auto inputs = GetDmlTensorDescs(tensors.inputs);
320320
auto outputs = GetDmlTensorDescs(tensors.outputs);
321321

322-
auto scope = dml::Scope(ctx->GetDmlDevice());
322+
auto scope = dml::Graph(ctx->GetDmlDevice());
323323
auto x = dml::InputTensor(scope, 0, inputs[0]);
324324

325325
ExpressionFunctor expression;
@@ -803,7 +803,7 @@ class DmlBinaryWithZeroKernel : public DmlKernel {
803803
auto inputs = GetDmlTensorDescs(tensors.inputs);
804804
auto outputs = GetDmlTensorDescs(tensors.outputs);
805805

806-
auto scope = dml::Scope(ctx->GetDmlDevice());
806+
auto scope = dml::Graph(ctx->GetDmlDevice());
807807
auto x = dml::InputTensor(scope, 0, inputs[0]);
808808
auto y = dml::InputTensor(scope, 1, inputs[1]);
809809
auto zero = dml::ZeroTensor(scope, x.GetOutputDesc().dataType,
@@ -914,7 +914,7 @@ class DmlSquaredDifferenceKernel : public DmlKernel {
914914
auto inputs = GetDmlTensorDescs(tensors.inputs);
915915
auto outputs = GetDmlTensorDescs(tensors.outputs);
916916

917-
auto scope = dml::Scope(ctx->GetDmlDevice());
917+
auto scope = dml::Graph(ctx->GetDmlDevice());
918918
auto x = dml::InputTensor(scope, 0, inputs[0]);
919919
auto y = dml::InputTensor(scope, 1, inputs[1]);
920920
auto diff = x - y;
@@ -1043,7 +1043,7 @@ class DmlApproximateEqualKernel : public DmlKernel {
10431043
auto inputs = GetDmlTensorDescs(tensors.inputs);
10441044
auto outputs = GetDmlTensorDescs(tensors.outputs);
10451045

1046-
auto scope = dml::Scope(ctx->GetDmlDevice());
1046+
auto scope = dml::Graph(ctx->GetDmlDevice());
10471047
auto x = dml::InputTensor(scope, 0, inputs[0]);
10481048
auto y = dml::InputTensor(scope, 1, inputs[1]);
10491049

@@ -1196,7 +1196,7 @@ class DmlBitCountKernel : public DmlKernel {
11961196
// 2D so that we can reduce each adjacent pair of counts.
11971197
dml::TensorDesc::Dimensions double_sizes = {1, 1, num_elements, 2};
11981198

1199-
auto scope = dml::Scope(ctx->GetDmlDevice());
1199+
auto scope = dml::Graph(ctx->GetDmlDevice());
12001200
auto in_64_bit = dml::InputTensor(scope, 0, in_desc);
12011201
auto in_32_bit = dml::Reinterpret(in_64_bit, DML_TENSOR_DATA_TYPE_UINT32,
12021202
double_sizes, dml::NullOpt);

tensorflow/core/kernels/dml_data_format_dim_map.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class DmlDataFormaDimMapKernel : public DmlKernel {
8282
}
8383
}
8484

85-
auto scope = dml::Scope(ctx->GetDmlDevice());
85+
auto scope = dml::Graph(ctx->GetDmlDevice());
8686

8787
DmlKernelTensors tensors = GetTensorInfos(ctx, {});
8888
auto inputs = GetDmlTensorDescs(tensors.inputs);

0 commit comments

Comments
 (0)