Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions demo/demo_llama3-8b.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ def get_rms_linear_kernel(num_tokens, output_dim):
O = graph.matmul(D, W)
graph.mark_output(O)
return graph.superoptimize(config="mlp")

def get_chunk_linear_kernel(num_tokens, intermediate_size, output_dim):
    """Build a superoptimized kernel that chunks X13 into (X1, X3) and computes X1 @ W2.

    Args:
        num_tokens: number of input rows (tokens).
        intermediate_size: width of each chunk half; X13 has 2 * intermediate_size columns.
        output_dim: output width of the linear projection (number of columns of W2).

    Returns:
        The superoptimized kernel; its outputs are (X1 @ W2, X3).
    """
    graph = mi.new_kernel_graph()
    X13 = graph.new_input(dims=(num_tokens, intermediate_size * 2))
    # BUG FIX: use the output_dim parameter instead of the hard-coded 4096 so the
    # kernel shape matches W2.shape[-1] supplied by the caller.
    W2 = graph.new_input(dims=(intermediate_size, output_dim))
    X1, X3 = graph.chunk(X13, 2, 1)
    output = graph.matmul(X1, W2)
    graph.mark_output(output)
    graph.mark_output(X3)
    return graph.superoptimize(config="mlp")

def mirage_llama(X, Wqkv, Wo, W13, W2, Kcache, Vcache, kernels):
func = kernels[0]
Expand All @@ -33,13 +43,13 @@ def mirage_llama(X, Wqkv, Wo, W13, W2, Kcache, Vcache, kernels):
Xk = Xk.view(Xk.shape[0], n_local_kv_heads, head_dim)
Xv = Xv.view(Xv.shape[0], n_local_kv_heads, head_dim)
output = flashinfer.single_prefill_with_kv_cache(Xq, Kcache, Vcache, causal=True)
X = torch.matmul(output.reshape(output_shape), Wo)
output_reshaped = output.reshape(output_shape)
X = torch.matmul(output_reshaped, Wo)
func = kernels[1]
outputs = func(inputs=[X, W13])
X13 = outputs[0]
X1, X3 = X13.chunk(2, -1)
output = torch.matmul(X1, W2)
return output
func = kernels[2]
return func(inputs=[X13, W2])[0]

if __name__ == "__main__":
X = torch.randn(num_tokens, 4096, dtype=torch.float16, device='cuda:0')
Expand All @@ -52,7 +62,8 @@ def mirage_llama(X, Wqkv, Wo, W13, W2, Kcache, Vcache, kernels):

k1 = get_rms_linear_kernel(num_tokens, Wqkv.shape[-1])
k2 = get_rms_linear_kernel(num_tokens, W13.shape[-1])
kernels = [k1, k2]
k3 = get_chunk_linear_kernel(num_tokens, intermediate_size, W2.shape[-1])
kernels = [k1, k2, k3]

for _ in range(16):
mirage_llama(X, Wqkv, Wo, W13, W2, Kcache, Vcache, kernels)
Expand Down
2 changes: 1 addition & 1 deletion include/mirage/kernel/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class Graph {
KNOperator *create_all_reduce_op(DTensor const &input, bool inplace);
// chunk operator
std::vector<DTensor> chunk(DTensor const &input, int chunk_size, int dim);
int chunk(DTensor const *input, int chunk_size, int dim);
std::vector<DTensor *> chunk(DTensor const *input, int chunk_size, int dim);
KNOperator *create_chunk_op(DTensor const &input, int chunk_size, int dim);
// customized operator
std::vector<DTensor> customized(std::vector<DTensor> const &inputs,
Expand Down
3 changes: 3 additions & 0 deletions include/mirage/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class KernelGraphGenerator {
std::atomic<int> num_tasks;
size_t max_depth;

// count number of chunk ops
std::vector<int> num_chunk_ops = std::vector<int>(4);

//
std::unordered_map<std::string, bool> seen_patterns;

Expand Down
1 change: 1 addition & 0 deletions include/mirage/search/search_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct SearchContext {
std::shared_ptr<kernel::Graph> kn_graph;
std::shared_ptr<threadblock::Graph> tb_graph;
SearchLevel level;
std::vector<int> ctx_num_chunk_ops = std::vector<int>(3);
};

void from_json(json const &j, SearchContext &c);
Expand Down
37 changes: 37 additions & 0 deletions include/mirage/threadblock/chunk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright 2023-2024 CMU
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "mirage/threadblock/operator.h"

namespace mirage {
namespace threadblock {

using namespace cutlass;

// Threadblock-level chunk operator: splits one input STensor into
// `chunk_size` pieces along dimension `chunk_dim` (device-side copy lives
// in mirage/threadblock/cuda/chunk.h).
class TBChunkOp : public TBOperator {
public:
  // bgraph: graph this operator is appended to; input: tensor to split;
  // chunk_size: number of chunks (the CUDA executor only exposes two
  // output pointers, so in practice this appears limited to 2 -- confirm);
  // dim: dimension index to split along.
  TBChunkOp(Graph *bgraph, STensor const &input, int chunk_size, int dim);
  ~TBChunkOp();

  // json conversion used when serializing the operator.
  operator json() const override;

public:
  int chunk_size, chunk_dim; // split factor and split dimension, as passed to the ctor
};

} // namespace threadblock
} // namespace mirage
136 changes: 136 additions & 0 deletions include/mirage/threadblock/cuda/chunk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/* Copyright 2023-2024 CMU
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "mirage/utils/fingerprint_functions.h"

namespace mirage {
namespace threadblock {

using namespace cutlass;
using namespace mirage::type;
using namespace mirage::config;
using namespace mirage::utils;

template <typename ElementType>
class ChunkExecutor {
public:
  // Splits the input tensor (logical shape `input_shape`, row-major over
  // x/y/z) along `chunk_dim`, copying the first piece to output1_ptr and
  // the second to output2_ptr. Work is strided across the threadblock.
  // NOTE(review): only two output pointers exist, so this effectively
  // assumes chunk_size == 2 -- confirm before supporting larger splits.
  CUTLASS_DEVICE
  ChunkExecutor(ElementType *input_ptr,
                ElementType *output1_ptr,
                ElementType *output2_ptr,
                int3 input_shape,
                int chunk_size,
                int chunk_dim,
                int thread_id,
                int num_threads) {
    // Each output shrinks the split dimension by a factor of chunk_size;
    // the other dimensions are unchanged.
    int3 output_shape = {
        chunk_dim == 0 ? input_shape.x / chunk_size : input_shape.x,
        chunk_dim == 1 ? input_shape.y / chunk_size : input_shape.y,
        chunk_dim == 2 ? input_shape.z / chunk_size : input_shape.z};
    // Total number of INPUT elements to distribute over the threads
    // (previously misnamed `output_num_elements`).
    int num_elements = input_shape.x * input_shape.y * input_shape.z;

    // BUG FIX: the loop previously started at 0 for every thread, so all
    // threads redundantly copied the same strided subset (multiples of
    // num_threads) while every other element was never written. Starting
    // each thread at its own thread_id covers each element exactly once.
    for (int i = thread_id; i < num_elements; i += num_threads) {
      // Decompose the linear index into (x, y, z) coordinates.
      int in_x = i / (input_shape.y * input_shape.z);
      int in_y = (i % (input_shape.y * input_shape.z)) / input_shape.z;
      int in_z = i % input_shape.z;

      // Decide which output this element belongs to and shift its
      // coordinate along the split dimension into that output's frame.
      bool to_second = false;
      if (chunk_dim == 0 && in_x >= output_shape.x) {
        to_second = true;
        in_x -= output_shape.x;
      } else if (chunk_dim == 1 && in_y >= output_shape.y) {
        to_second = true;
        in_y -= output_shape.y;
      } else if (chunk_dim == 2 && in_z >= output_shape.z) {
        to_second = true;
        in_z -= output_shape.z;
      }

      // BUG FIX: re-linearize with the OUTPUT shape for BOTH outputs. The
      // old code stored the first chunk at the input's linear index, which
      // is only correct when chunk_dim == 0; for dims 1 and 2 it wrote the
      // first output with the input's (larger) strides.
      int out_idx = (in_x * (output_shape.y * output_shape.z)) +
                    (in_y * output_shape.z) + in_z;
      (to_second ? output2_ptr : output1_ptr)[out_idx] = input_ptr[i];
    }
  }
};

class TBChunkFingerprinter {
public:
  // Fingerprint counterpart of ChunkExecutor: splits the input fingerprint
  // tensor along `chunk_dim` into two outputs, folding each element through
  // compute_mul_fingerprint(x, 1). Work is strided across the threadblock.
  // NOTE(review): only two output pointers exist, so this effectively
  // assumes chunk_size == 2 -- confirm before supporting larger splits.
  CUTLASS_DEVICE
  TBChunkFingerprinter(FPType *input_ptr,
                       FPType *output1_ptr,
                       FPType *output2_ptr,
                       int3 input_shape,
                       int chunk_size,
                       int chunk_dim,
                       int thread_id,
                       int num_threads) {
    // Each output shrinks the split dimension by a factor of chunk_size.
    int3 output_shape = {
        chunk_dim == 0 ? input_shape.x / chunk_size : input_shape.x,
        chunk_dim == 1 ? input_shape.y / chunk_size : input_shape.y,
        chunk_dim == 2 ? input_shape.z / chunk_size : input_shape.z};
    // Total number of INPUT elements to distribute over the threads
    // (previously misnamed `output_num_elements`).
    int num_elements = input_shape.x * input_shape.y * input_shape.z;

    FPType one = 1;
    // BUG FIX: the loop previously started at 0 for every thread, so all
    // threads redundantly processed the same strided subset while every
    // other element was never written. Starting each thread at its own
    // thread_id covers each element exactly once across the block.
    for (int i = thread_id; i < num_elements; i += num_threads) {
      // Decompose the linear index into (x, y, z) coordinates.
      int in_x = i / (input_shape.y * input_shape.z);
      int in_y = (i % (input_shape.y * input_shape.z)) / input_shape.z;
      int in_z = i % input_shape.z;

      // Decide which output this element belongs to and shift its
      // coordinate along the split dimension into that output's frame.
      bool to_second = false;
      if (chunk_dim == 0 && in_x >= output_shape.x) {
        to_second = true;
        in_x -= output_shape.x;
      } else if (chunk_dim == 1 && in_y >= output_shape.y) {
        to_second = true;
        in_y -= output_shape.y;
      } else if (chunk_dim == 2 && in_z >= output_shape.z) {
        to_second = true;
        in_z -= output_shape.z;
      }

      // BUG FIX: re-linearize with the OUTPUT shape for BOTH outputs; the
      // old code indexed the first output with the input's linear index,
      // which is wrong for chunk_dim 1 and 2 (wrong strides).
      int out_idx = (in_x * (output_shape.y * output_shape.z)) +
                    (in_y * output_shape.z) + in_z;
      (to_second ? output2_ptr : output1_ptr)[out_idx] =
          compute_mul_fingerprint(input_ptr[i], one);
    }
  }
};

} // namespace threadblock
} // namespace mirage
5 changes: 5 additions & 0 deletions include/mirage/threadblock/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ class Graph {
STensor *reduction(STensor const *A, int dim);
TBOperator *create_reduction_op(STensor const &A, int dim);

// chunk operator
std::vector<STensor> chunk(STensor const &A, int chunk_size, int dim);
std::vector<STensor *> chunk(STensor const *A, int chunk_size, int dim);
TBOperator *create_chunk_op(STensor const &A, int chunk_size, int dim);

// reduction_to_dimx operator
STensor reduction_to_dimx(STensor const &A, int dim);
TBOperator *create_reduction_to_dimx_op(STensor const &A, int dim);
Expand Down
67 changes: 67 additions & 0 deletions include/mirage/threadblock/serializer/chunk_serializer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/* Copyright 2023-2024 CMU
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <vector_types.h>

namespace mirage {
namespace threadblock {

// Reads a chunk operator's parameters back out of `params`, in exactly the
// order serialize_chunk_op_parameters wrote them, advancing `param_idx`
// past the consumed entries.
CUTLASS_HOST_DEVICE
void deserialize_chunk_op_parameters(int const *params,
                                     int &param_idx,
                                     int3 &input_shape,
                                     int &chunk_size,
                                     int &chunk_dim,
                                     int &input_smem_offset,
                                     int &output1_smem_offset,
                                     int &output2_smem_offset) {
  int idx = param_idx;
  // Logical input shape (x, y, z).
  input_shape.x = params[idx++];
  input_shape.y = params[idx++];
  input_shape.z = params[idx++];
  // Split configuration.
  chunk_size = params[idx++];
  chunk_dim = params[idx++];
  // Shared-memory offsets of the input and the two outputs.
  input_smem_offset = params[idx++];
  output1_smem_offset = params[idx++];
  output2_smem_offset = params[idx++];
  param_idx = idx;
}

// Appends a chunk operator's parameters to `params`, advancing `param_idx`;
// deserialize_chunk_op_parameters consumes them in the same order.
inline void serialize_chunk_op_parameters(int *params,
                                          int &param_idx,
                                          int3 input_shape,
                                          int chunk_size,
                                          int chunk_dim,
                                          int input_smem_offset,
                                          int output1_smem_offset,
                                          int output2_smem_offset) {
  // Order matters: shape, split configuration, then smem offsets.
  int const values[] = {input_shape.x,
                        input_shape.y,
                        input_shape.z,
                        chunk_size,
                        chunk_dim,
                        input_smem_offset,
                        output1_smem_offset,
                        output2_smem_offset};
  for (int v : values) {
    params[param_idx++] = v;
  }
  // Guard against overflowing the fixed-size parameter buffer.
  assert(param_idx <= NewKernelParams::MAX_NUM_PARAMETERS);
}

} // namespace threadblock
} // namespace mirage
Loading
Loading