Skip to content

Commit 1102150

Browse files
committed
support channels last inputs in xnnpack
1 parent 9d726e8 commit 1102150

File tree

42 files changed

+471
-107
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+471
-107
lines changed

.Package.swift/.watchman-cookie-madragna-mac-84335-443

Whitespace-only changes.

.ci/.watchman-cookie-madragna-mac-84335-443

Whitespace-only changes.

.github/.watchman-cookie-madragna-mac-84335-443

Whitespace-only changes.

.gitignore

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,19 @@ xcuserdata/
4040
.swiftpm/
4141
*.xcworkspace/
4242
*.xcframework/
43+
44+
# misc
45+
/.vscode/
46+
*.so
47+
*.dylib
48+
/cmake_wrapper.sh
49+
/data/
50+
/devtools/bundled_program/serialize/
51+
/exir/_serialize/program.fbs
52+
/exir/_serialize/scalar_type.fbs
53+
/include/
54+
/share/
55+
/version.py
56+
57+
# Android
58+
*.aar

.vscode/launch.json

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@
1212
"args": [
1313
"--model_path=./add.pte",
1414
]
15-
}
15+
},
16+
{
17+
"name": "Debug python proj",
18+
"type": "debugpy",
19+
"request": "launch",
20+
"module": "unittest",
21+
"args": [
22+
"./backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py"
23+
]
24+
},
1625
]
1726
}

.vscode/settings.json

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,22 @@
6262
"algorithm": "cpp",
6363
"iterator": "cpp",
6464
"tuple": "cpp",
65-
"span": "cpp"
65+
"span": "cpp",
66+
"*.inc": "cpp",
67+
"alignedvector3": "cpp"
6668
},
6769
"C_Cpp.default.compilerPath": "/library/developer/commandlinetools/usr/bin/c++",
68-
"python.analysis.typeCheckingMode": "off"
70+
"python.analysis.typeCheckingMode": "off",
71+
"python.testing.unittestArgs": [
72+
"-v",
73+
"-s",
74+
"./backends",
75+
"-p",
76+
"test_*.py"
77+
],
78+
"python.testing.pytestEnabled": true,
79+
"python.testing.unittestEnabled": false,
80+
"python.testing.pytestArgs": [
81+
"."
82+
]
6983
}

backends/.watchman-cookie-madragna-mac-84335-443

Whitespace-only changes.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12+
13+
namespace vkcompute {
14+
15+
DynamicDispatchNode::DynamicDispatchNode(
16+
ComputeGraph& graph,
17+
const PickShaderFn& pick_shader_fn,
18+
const PickGlobalFn& pick_global_wg_fn,
19+
const PickLocalFn& pick_local_wg_fn,
20+
const std::vector<ArgGroup>& args,
21+
const vkapi::ParamsBindList& params,
22+
const std::vector<PushConstantDataInfo>& push_constants,
23+
const vkapi::SpecVarList& spec_vars,
24+
const std::vector<ValueRef>& resize_args,
25+
const ResizeFunction& resize_fn)
26+
: DispatchNode(
27+
graph,
28+
pick_shader_fn(&graph, args, resize_args),
29+
pick_global_wg_fn(&graph, args, resize_args),
30+
pick_local_wg_fn(&graph, args, resize_args),
31+
args,
32+
params,
33+
push_constants,
34+
spec_vars,
35+
resize_args,
36+
resize_fn),
37+
pick_shader_fn_(pick_shader_fn),
38+
pick_global_wg_fn_(pick_global_wg_fn),
39+
pick_local_wg_fn_(pick_local_wg_fn) {}
40+
41+
void DynamicDispatchNode::encode(ComputeGraph* graph) {
42+
shader_ = pick_shader_fn_(graph, args_, resize_args_);
43+
global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_);
44+
local_workgroup_size_ =
45+
utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_));
46+
DispatchNode::encode(graph);
47+
}
48+
49+
} // namespace vkcompute
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/containers/PushConstantData.h>
14+
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
15+
16+
#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
17+
18+
namespace vkcompute {
19+
20+
class ComputeGraph;
21+
22+
/*
23+
* Represents a single shader execution op in a ML model.
24+
*/
25+
class DynamicDispatchNode final : public DispatchNode {
26+
friend class ComputeGraph;
27+
28+
public:
29+
using PickShaderFn = const std::function<vkapi::ShaderInfo(
30+
ComputeGraph*,
31+
const std::vector<ArgGroup>&,
32+
const std::vector<ValueRef>&)>;
33+
using PickGlobalFn = const std::function<utils::uvec3(
34+
ComputeGraph*,
35+
const std::vector<ArgGroup>&,
36+
const std::vector<ValueRef>&)>;
37+
using PickLocalFn = const std::function<utils::uvec3(
38+
ComputeGraph*,
39+
const std::vector<ArgGroup>&,
40+
const std::vector<ValueRef>&)>;
41+
42+
explicit DynamicDispatchNode(
43+
ComputeGraph& graph,
44+
const PickShaderFn& pick_shader_fn,
45+
const PickGlobalFn& pick_global_wg_fn,
46+
const PickLocalFn& pick_local_wg_fn,
47+
const std::vector<ArgGroup>& args,
48+
const vkapi::ParamsBindList& params,
49+
const std::vector<PushConstantDataInfo>& push_constants,
50+
const vkapi::SpecVarList& spec_vars,
51+
const std::vector<ValueRef>& resize_args,
52+
const ResizeFunction& resize_fn = nullptr);
53+
54+
~DynamicDispatchNode() override = default;
55+
56+
void encode(ComputeGraph* graph) override;
57+
58+
protected:
59+
const PickShaderFn pick_shader_fn_;
60+
const PickGlobalFn pick_global_wg_fn_;
61+
const PickLocalFn pick_local_wg_fn_;
62+
63+
public:
64+
operator bool() const {
65+
return shader_;
66+
}
67+
};
68+
69+
} // namespace vkcompute
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
layout(std430) buffer;
14+
15+
${layout_declare_tensor(0, "w", "t_out", "float", "texture3d")}
16+
${layout_declare_tensor(1, "r", "t_in1", "float", "texture3d")}
17+
${layout_declare_tensor(2, "r", "t_in2", "float", "texture3d")}
18+
19+
layout(push_constant) uniform restrict Block {
20+
ivec4 out_sizes;
21+
ivec4 in1_sizes;
22+
ivec4 in2_sizes;
23+
};
24+
25+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
26+
27+
void main() {
28+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
29+
30+
if (any(greaterThanEqual(pos, out_sizes.xyz))) {
31+
return;
32+
}
33+
34+
35+
vec4 out_texel = vec4(0.0);
36+
for (int row = 0; row < in1_sizes.y; ++row) {
37+
ivec3 in_pos = ivec3(pos.x, row, pos.z);
38+
vec4 in1_texel = texelFetch(t_in1, in_pos, 0);
39+
vec4 in2_texel = texelFetch(t_in2, in_pos, 0);
40+
41+
out_texel += in1_texel * in2_texel;
42+
}
43+
44+
imageStore(t_out, pos, out_texel + ${OFFSET});
45+
}

0 commit comments

Comments
 (0)