Skip to content

Commit 34a6571

Browse files
Maximilian.S authored and phaniarnab committed
[SYSTEMDS-3803] DML-bodied Util Function for Transposing ABCD to ACBD
This patch adds a simple utility function that swaps the 2nd and 3rd axes of a matrix stored in ABCD layout (producing ACBD layout), which is required for the multi-head attention implementation. Closes #2151
1 parent a487be0 commit 34a6571

File tree

3 files changed

+112
-0
lines changed

3 files changed

+112
-0
lines changed

scripts/nn/util.dml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,28 @@ top_k2d = function(matrix[double] X, int k, int C, int Hin, int Win)
380380
indices = transpose_NCHW_to_CNHW(indices_K_NHW, N)
381381
}
382382

383+
transpose_ABCD_to_ACBD = function(matrix[double] X, int B, int C)
    return (matrix[double] out) {
  /*
   * Reshape util for tensors in ABCD format that are stored as a
   * 2D matrix: swaps the 2nd and 3rd axes.
   *
   * The swap is composed of two calls to `transpose_NCHW_to_CNHW`
   * followed by a final reshape back to A rows.
   *
   * Inputs:
   *  - X: Inputs, of shape (A, B*C*D).
   *  - B: Dimension of 2nd axis.
   *  - C: Dimension of 3rd axis.
   *
   * Outputs:
   *  - out: Outputs with the 2nd and 3rd axes swapped, of
   *      shape (A, C*B*D).
   */
  num_rows = nrow(X)  # A
  num_cols = ncol(X)  # B*C*D

  # 1st swap, B is passed as the channel count: (A, B*C*D) -> (B, A*C*D)
  B_major = transpose_NCHW_to_CNHW(X, B)

  # 2nd swap, A*C is passed as the channel count: (B, A*C*D) -> (A*C, B*D)
  AC_major = transpose_NCHW_to_CNHW(B_major, num_rows * C)

  # merge the C rows that belong to each A: (A*C, B*D) -> (A, C*B*D)
  out = matrix(AC_major, rows=num_rows, cols=num_cols)
}

src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ public void transpose_NCHW_to_CNHW() {
108108
run("transpose_NCHW_to_CNHW.dml");
109109
}
110110

111+
@Test
112+
public void transpose_ABCD_to_ACBD() {
113+
run("transpose_ABCD_to_ACBD.dml");
114+
}
115+
111116
@Test
112117
public void logcosh(){
113118
run("logcosh.dml");
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
source("src/test/scripts/applications/nn/util.dml") as test_util
23+
source("scripts/nn/util.dml") as util
24+
25+
26+
transpose_ABCD_to_ACBD = function() {
  /*
   * Test for `transpose_ABCD_to_ACBD` function: compares the result
   * for a (2, 3*4*5) input holding the values 1..120 against a
   * hand-written expected matrix.
   */
  print("Testing transpose_ABCD_to_ACBD function.")

  # Axis sizes
  A = 2
  B = 3
  C = 4
  D = 5

  # Input filled with the consecutive values 1..A*B*C*D
  num_cells = A*B*C*D
  X = matrix(seq(1, num_cells), rows=A, cols=B*C*D)

  # Expected values; each text line lists B groups of D values,
  # and each blank-line-separated group of C lines forms one output row
  expected =
    matrix("1 2 3 4 5 21 22 23 24 25 41 42 43 44 45
            6 7 8 9 10 26 27 28 29 30 46 47 48 49 50
            11 12 13 14 15 31 32 33 34 35 51 52 53 54 55
            16 17 18 19 20 36 37 38 39 40 56 57 58 59 60

            61 62 63 64 65 81 82 83 84 85 101 102 103 104 105
            66 67 68 69 70 86 87 88 89 90 106 107 108 109 110
            71 72 73 74 75 91 92 93 94 95 111 112 113 114 115
            76 77 78 79 80 96 97 98 99 100 116 117 118 119 120",
           rows=A, cols=C*B*D)

  # Function under test
  actual = util::transpose_ABCD_to_ACBD(X, B, C)

  # Equivalency check
  test_util::check_all_close(actual, expected, 1e-10)
}
56+
57+
58+
transpose_ABCD_to_ACBD_single_val = function() {
  /*
   * Test for `transpose_ABCD_to_ACBD` function on the degenerate
   * case where every axis has size 1: the output is expected to
   * equal the input.
   */
  print("Testing transpose_ABCD_to_ACBD function with single value.")

  # Every axis has length one
  A = 1
  B = 1
  C = 1
  D = 1
  X = matrix(seq(1, A*B*C*D), rows=A, cols=B*C*D)

  # The input itself serves as the expected output
  expected = X
  actual = util::transpose_ABCD_to_ACBD(X, B, C)

  # Equivalency check
  test_util::check_all_close(actual, expected, 1e-10)
}
79+
80+
81+
# Run both test cases defined in this script
transpose_ABCD_to_ACBD()
transpose_ABCD_to_ACBD_single_val()

0 commit comments

Comments
 (0)