Skip to content

Commit 7ada55f

Browse files
committed
Merge remote-tracking branch 'origin/main' into java-unknown-type
2 parents d0e2800 + dfd3dbe commit 7ada55f

34 files changed

+424
-305
lines changed

.github/scripts/extract_benchmark_results.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ def transform(
341341
benchmark_results: List,
342342
benchmark_config: Dict[str, str],
343343
job_name: str,
344+
job_report: Any = {},
344345
) -> List:
345346
"""
346347
Transform the benchmark results into the format writable into the benchmark database
@@ -361,6 +362,7 @@ def transform(
361362
# Just keep a copy of the benchmark config here
362363
"benchmark_config": json.dumps(benchmark_config),
363364
"job_conclusion": "SUCCESS",
365+
"job_arn": job_report.get("arn", ""),
364366
},
365367
},
366368
"model": {
@@ -446,6 +448,7 @@ def transform_failure_record(
446448
"app_type": app_type,
447449
"job_conclusion": result,
448450
"failure_type": level,
451+
"job_arn": report.get("arn", ""),
449452
"job_report": json.dumps(report),
450453
},
451454
},
@@ -512,6 +515,7 @@ def get_benchmark_config(
512515
def extract_benchmark_result_from_artifact(
513516
artifact: Dict[str, Any],
514517
benchmark_config: Dict[str, str],
518+
job_report: Any,
515519
) -> List[Any]:
516520
job_name = artifact.get("job_name", "")
517521
artifact_type = artifact.get("type", "")
@@ -532,7 +536,9 @@ def extract_benchmark_result_from_artifact(
532536
)
533537
if not benchmark_results:
534538
return []
535-
return transform(app_type, benchmark_results, benchmark_config, job_name)
539+
return transform(
540+
app_type, benchmark_results, benchmark_config, job_name, job_report
541+
)
536542

537543

538544
def get_app_type(type: str):
@@ -674,7 +680,7 @@ def process_benchmark_results(content: Any, app: str, benchmark_configs: str):
674680
for job_artifact in job_artifacts:
675681
# generate result for each schema
676682
results = extract_benchmark_result_from_artifact(
677-
job_artifact, benchmark_config
683+
job_artifact, benchmark_config, job_report
678684
)
679685
all_benchmark_results.extend(results)
680686
return all_benchmark_results

backends/cadence/aot/fuse_ops.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -526,34 +526,14 @@ class FuseCascadedViewOps(ExportPass):
526526
Fuse a cascaded chain of view ops
527527
"""
528528

529-
# Find a chain of view ops, and fuse them into a single permute op.
530-
531529
def fuse_cascaded_view_ops(self, graph_module: torch.fx.GraphModule):
532-
graph = graph_module.graph
533-
for node in graph.nodes:
534-
# We are only interested in view ops
535-
if node.target != exir_ops.edge.aten.view_copy.default:
536-
continue
537-
538-
# Get the cascaded chain of view ops starting at node
539-
cascaded_view_ops = get_cascaded_ops(
540-
[node], [exir_ops.edge.aten.view_copy.default]
541-
)
542-
# The chain must have more than 1 node
543-
if len(cascaded_view_ops) == 1:
530+
view_target = exir_ops.edge.aten.view_copy.default
531+
for view_node in graph_module.graph.find_nodes(op="call_function", target=view_target, sort=True):
532+
input_view = view_node.args[0]
533+
if input_view.op != "call_function" or input_view.target != view_target:
544534
continue
545535

546-
last_view_node = cascaded_view_ops[-1]
547-
with graph.inserting_before(last_view_node):
548-
new_view = graph.call_function(
549-
exir_ops.edge.aten.view_copy.default,
550-
args=(node.args[0], last_view_node.args[1]),
551-
)
552-
last_view_node.replace_all_uses_with(new_view)
553-
554-
# Now erase the chain
555-
for v in reversed(cascaded_view_ops):
556-
graph.erase_node(v)
536+
view_node.replace_input_with(input_view, input_view.args[0])
557537

558538
graph_module.recompile()
559539

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,26 @@ def forward(self, x):
222222
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
223223
)
224224

225+
def test_view_fusion_branched(self):
226+
class ViewFusion(torch.nn.Module):
227+
def forward(self, x):
228+
y = x.view([1, 8, 15])
229+
z = y.view([1, 1, 120])
230+
t = y.view([120, 1, 1])
231+
return z, t
232+
233+
x = torch.randn(8, 5, 3)
234+
graph_module = (
235+
compiler.export_to_cadence(ViewFusion(), (x,))
236+
.exported_program()
237+
.graph_module
238+
)
239+
graph_module.graph.eliminate_dead_code()
240+
# z and t should be fused and y should be eliminated.
241+
self.assertEqual(
242+
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2
243+
)
244+
225245
def test_force_quant_dequant_fusion(self):
226246
class M(torch.nn.Module):
227247
def __init__(self):

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,11 @@ utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
179179
return utils::kChannelsPacked;
180180
}
181181

182+
bool ComputeGraph::device_name_contains(const char* substr) {
183+
return context_->adapter_ptr()->device_name().find(substr) !=
184+
std::string::npos;
185+
}
186+
182187
void ComputeGraph::check_no_active_value_ptrs() {
183188
VK_CHECK_COND(
184189
values_in_use_ == 0,

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,15 @@ class ComputeGraph final {
443443
utils::GPUMemoryLayout suggested_memory_layout(
444444
const std::vector<int64_t>& sizes);
445445

446+
inline bool device_is_adreno() {
447+
return context_->adapter_ptr()->device_type() == vkapi::DeviceType::ADRENO;
448+
}
449+
const std::string& device_name() {
450+
return context()->adapter_ptr()->device_name();
451+
}
452+
453+
bool device_name_contains(const char* substr);
454+
446455
//
447456
// Graph Building
448457
//
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define T ${buffer_scalar_type(DTYPE)}
14+
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
15+
16+
#define TILE_ROWS ${TILE_ROWS}
17+
18+
#define NGROUPS 8
19+
#define NWORKERS 8
20+
21+
${define_required_extensions(DTYPE)}
22+
23+
$if WEIGHT_STORAGE == "buffer":
24+
${define_required_extensions("int8")}
25+
26+
#extension GL_EXT_control_flow_attributes : require
27+
28+
layout(std430) buffer;
29+
30+
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
31+
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
32+
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
33+
${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}
34+
35+
layout(push_constant) uniform restrict Block {
36+
ivec4 out_sizes;
37+
ivec4 in_sizes;
38+
ivec4 weight_sizes;
39+
};
40+
41+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
42+
43+
shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS];
44+
45+
void main() {
46+
const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
47+
const uint out_col = gl_GlobalInvocationID.x << 2;
48+
49+
const int gid = int(gl_LocalInvocationID.x); // group id
50+
const int wid = int(gl_LocalInvocationID.z); // worker id
51+
52+
if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
53+
return;
54+
}
55+
56+
VEC4_T a[TILE_ROWS];
57+
VEC4_T b[4];
58+
VEC4_T local_c[TILE_ROWS];
59+
60+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
61+
local_c[i] = VEC4_T(0.0);
62+
}
63+
64+
$if SCALES_STORAGE == "buffer":
65+
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
66+
$else:
67+
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
68+
69+
for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) {
70+
// Preload t_weight
71+
[[unroll]] for (int i = 0; i < 4; i++) {
72+
$if WEIGHT_STORAGE == "buffer":
73+
b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2];
74+
$else:
75+
b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
76+
}
77+
// Preload t_in
78+
for (int i = 0; i < TILE_ROWS; i++) {
79+
$if IN_STORAGE == "buffer":
80+
a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
81+
$else:
82+
a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
83+
}
84+
85+
// Accumulate partial output
86+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
87+
local_c[i] += a[i].x * b[0] +
88+
a[i].y * b[1] +
89+
a[i].z * b[2] +
90+
a[i].w * b[3];
91+
}
92+
}
93+
94+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
95+
partial_c[gid][wid][i] = local_c[i];
96+
}
97+
98+
memoryBarrierShared();
99+
barrier();
100+
101+
if (wid != 0) {
102+
return;
103+
}
104+
105+
VEC4_T c[TILE_ROWS];
106+
107+
for (int row = 0; row < TILE_ROWS; ++row) {
108+
c[row] = VEC4_T(0.0);
109+
[[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) {
110+
c[row] += partial_c[gid][worker][row];
111+
}
112+
}
113+
114+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
115+
$if OUT_STORAGE == "buffer":
116+
if (out_row + i < out_sizes.y) {
117+
t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
118+
}
119+
$else:
120+
imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
121+
}
122+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
q_8w_linear_coop:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
IN_STORAGE: texture3d
11+
OUT_STORAGE: texture3d
12+
WEIGHT_STORAGE: texture2d
13+
SCALES_STORAGE: texture2d
14+
TILE_ROWS: 4
15+
generate_variant_forall:
16+
TILE_ROWS:
17+
- VALUE: 1
18+
SUFFIX: o4x1
19+
shader_variants:
20+
- NAME: q_8w_linear_coop_texture3d_texture3d_texture2d_texture2d_float
21+
- NAME: q_8w_linear_coop_buffer_buffer_texture2d_texture2d_float
22+
IN_STORAGE: buffer
23+
OUT_STORAGE: buffer
24+
- NAME: q_8w_linear_coop_buffer_buffer_buffer_buffer_float
25+
IN_STORAGE: buffer
26+
OUT_STORAGE: buffer
27+
WEIGHT_STORAGE: buffer
28+
SCALES_STORAGE: buffer

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,17 @@
1717

1818
${define_required_extensions(DTYPE)}
1919

20-
$if STORAGE == "buffer":
20+
$if WEIGHT_STORAGE == "buffer":
2121
${define_required_extensions("int8")}
2222

2323
#extension GL_EXT_control_flow_attributes : require
2424

2525
layout(std430) buffer;
2626

27-
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
28-
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
29-
${layout_declare_tensor(B, "r", "t_weight", "int8", STORAGE, is_scalar_array=False)}
30-
${layout_declare_tensor(B, "r", "t_scales", DTYPE, STORAGE, is_scalar_array=False)}
27+
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
28+
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
29+
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
30+
${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}
3131

3232

3333
layout(push_constant) uniform restrict Block {
@@ -50,10 +50,10 @@ void main() {
5050
VEC4_T b[4];
5151
VEC4_T c[TILE_ROWS];
5252

53-
$if STORAGE == "buffer":
53+
$if SCALES_STORAGE == "buffer":
5454
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
5555
$else:
56-
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));
56+
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
5757

5858
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
5959
c[i] = VEC4_T(0.0);
@@ -62,30 +62,32 @@ void main() {
6262
for (int pos = 0; pos < in_sizes.x; pos += 4) {
6363
// Preload weight tensor
6464
[[unroll]] for (int i = 0; i < 4; i++) {
65-
$if STORAGE == "buffer":
66-
b[i] = t_weight[((pos + i) * B_sizes.x + out_col) >> 2];
65+
$if WEIGHT_STORAGE == "buffer":
66+
b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2];
6767
$else:
68-
b[i] = VEC4_T(texelFetch(t_weight, ivec3(out_col >> 2, pos + i, 0), 0));
68+
b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
6969
}
7070

7171
// Preload input tensor
7272
[[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
73-
$if STORAGE == "buffer":
74-
a[i] = t_in[((out_row + i) * in_sizes.x + (pos)) >> 2];
73+
$if IN_STORAGE == "buffer":
74+
a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
7575
$else:
7676
a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
7777
}
7878

79-
// Compute partial output
79+
// Accumulate output
8080
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
8181
c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3];
8282
}
8383
}
8484

85-
// Store output tensor
85+
// Store to output tensor
8686
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
87-
$if STORAGE == "buffer":
88-
t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
87+
$if OUT_STORAGE == "buffer":
88+
if (out_row + i < out_sizes.y) {
89+
t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
90+
}
8991
$else:
9092
imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
9193
}

0 commit comments

Comments
 (0)