Skip to content

Commit 8821346

Browse files
committed
Add config validation check for KThreads
1 parent d794dc7 commit 8821346

File tree

2 files changed

+77
-6
lines changed

2 files changed

+77
-6
lines changed

bench_mlp.sh

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#!/bin/bash -e
# Benchmark / tuning driver for benchgc MLP patterns.
# Usage: bench_mlp.sh --bench   (performance mode, MODE=P)
#        bench_mlp.sh --tune    (tuning mode,      MODE=T)
# Optional environment:
#   NUM_THREADS       - OMP thread count (default: 56)
#   PATH_TO_JEMALLOC  - path to libjemalloc.so to LD_PRELOAD

# if [ -z "$PATH_TO_JEMALLOC" ]; then
#     echo "PATH_TO_JEMALLOC not set."
#     exit 1
# fi

# export PATH_TO_JEMALLOC=/home/yifei/ipex_env/jemalloc/lib/libjemalloc.so

# Respect a caller-provided NUM_THREADS; fall back to 56.
# (An unconditional assignment here would make the emptiness check below dead code.)
export NUM_THREADS=${NUM_THREADS:-56}

# Append jemalloc to LD_PRELOAD only when a path was given, and avoid
# emitting stray ':' separators (empty path elements) in LD_PRELOAD.
if [ -n "$PATH_TO_JEMALLOC" ]; then
    export LD_PRELOAD=${LD_PRELOAD:+${LD_PRELOAD}:}${PATH_TO_JEMALLOC}
fi

# Parse command-line flags into MODE (P = bench, T = tune).
for arg in "$@"; do
    case $arg in
    --bench)
        MODE=P
        ;;
    --tune)
        MODE=T
        ;;
    *)
        echo Unsupported option: "$arg"
        exit 1
        ;;
    esac
done

if [ -z "$MODE" ]; then
    echo "Mode not set."
    exit 1
fi

if [ -z "$NUM_THREADS" ]; then
    echo "NUM_THREADS not set."
    exit 1
fi

export OMP_NUM_THREADS=${NUM_THREADS}
# Pin the benchmark to cores [START_NODE, END_NODE] on NUMA node 0.
export START_NODE=0
export END_NODE=$(($OMP_NUM_THREADS-1))

if [ "$MODE" = "P" ]; then
    # Alternative benchmark configurations (kept for reference):
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=P --driver=pattern --case mlp --batch_size=128 --hidden_size_list=16x512x256x128 --has_bias=1x1x1 --act_type=relu --warm_up 500 --repeat 5000
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=P --driver=pattern --case mlp --batch_size=128 --hidden_size_list=512x1024x1024x512x256 --has_bias=1x1x1x1 --act_type=relu --warm_up 500 --repeat 5000
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=P --driver=pattern --case mlp --batch_size=32 --hidden_size_list=4096x4096x11008x4096 --has_bias=1x1x1 --act_type=relu --warm_up 500 --repeat 500
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=P --driver=pattern --case mlp --batch_size=128 --hidden_size_list=4096x4096x11008x4096 --has_bias=1x1x1 --act_type=relu --warm_up 500 --repeat 500
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=P --driver=pattern --case mlp --batch_size=128 --hidden_size_list=16x512x256x128 --dtype=bf16 --has_bias=1x1x1 --act_type=relu --warm_up 500 --repeat 5000
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=P --driver=pattern --case mlp --batch_size=128 --hidden_size_list=512x1024x1024x512x256 --dtype=bf16 --has_bias=1x1x1x1 --act_type=relu --warm_up 500 --repeat 5000
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=P --driver=pattern --case mlp --batch_size=32 --hidden_size_list=4096x4096x11008x4096 --dtype=bf16 --has_bias=1x1x1 --act_type=relu --warm_up 500 --repeat 2000
    numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=P --driver=pattern --case mlp --batch_size=128 --hidden_size_list=4096x4096x11008x4096 --dtype=bf16 --has_bias=1x1x1 --act_type=relu --warm_up 500 --repeat 2000
else
    # Alternative tuning configurations (kept for reference):
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=128 --hidden_size_list=4096x4096 --has_bias=1 --dtype=bf16 --act_type=relu --warm_up 200 --repeat 2000
    numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=128 --hidden_size_list=4096x11008 --has_bias=1 --dtype=bf16 --act_type=relu --warm_up 100 --repeat 500
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=32 --hidden_size_list=4096x4096 --has_bias=1 --dtype=bf16 --act_type=relu --warm_up 500 --repeat 5000
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=128 --hidden_size_list=1024x512 --has_bias=1 --act_type=relu --warm_up 500 --repeat 5000
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=128 --hidden_size_list=1024x512x256 --has_bias=1x1 --act_type=relu --warm_up 500 --repeat 5000
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=128 --hidden_size_list=512x1024x1024x512x256 --has_bias=1x1x1x1 --act_type=relu --warm_up 500 --repeat 5000
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=32 --hidden_size_list=4096x4096x11008x4096 --has_bias=1x1x1 --act_type=relu --warm_up 500 --repeat 500
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=128 --hidden_size_list=4096x4096x11008x4096 --has_bias=1x1x1 --act_type=relu --warm_up 500 --repeat 500
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=128 --hidden_size_list=16x512x256x128 --dtype=bf16 --has_bias=1x1x1 --act_type=relu --warm_up 500 --repeat 5000
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=128 --hidden_size_list=512x1024x1024x512x256 --dtype=bf16 --has_bias=1x1x1x1 --act_type=relu --warm_up 500 --repeat 5000
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=32 --hidden_size_list=4096x4096x11008x4096 --dtype=bf16 --has_bias=1x1x1 --act_type=relu --warm_up 500 --repeat 2000
    # numactl -C $START_NODE-$END_NODE -m 0 python -m benchgc --mode=T --driver=pattern --case mlp --batch_size=128 --hidden_size_list=4096x4096x11008x4096 --dtype=bf16 --has_bias=1x1x1 --act_type=relu --warm_up 500 --repeat 2000
fi

lib/gc/Analysis/MatmulConfigAnalysis.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
3737
return ss;
3838
}
3939

40-
bool validateConfig(const MatmulConfig &cfg) {
40+
bool validateConfig(const MatmulConfig &cfg, ArrayRef<uint32_t> shape = {}) {
4141
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
4242
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
4343
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
@@ -47,6 +47,12 @@ bool validateConfig(const MatmulConfig &cfg) {
4747
cfg.NBlock % cfg.innerMostNBlock != 0 ||
4848
cfg.KBlock % cfg.innerMostKBlock != 0)
4949
return false;
50+
if (!shape.empty()) {
51+
// KThreads will not shrink automatically
52+
// K is shape[2]
53+
if (llvm::divideCeil(shape[2], cfg.KBlock) < cfg.KThreads)
54+
return false;
55+
}
5056
return true;
5157
}
5258

@@ -179,7 +185,7 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
179185
ArrayRef<uint32_t> shape,
180186
const MatmulConfig &config,
181187
CPUTargetDescriptionAnalysis &sysDesc) {
182-
assert(validateConfig(config) && "config is invalid");
188+
assert(validateConfig(config, shape) && "config is invalid");
183189
assert(shape.size() >= 3 && "shape.size() should >= 3");
184190
uint32_t M = shape[0], N = shape[1];
185191
double cost = 0;
@@ -361,7 +367,8 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
361367
}
362368

363369
// read the config from the attributes for tuning
364-
bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
370+
bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs,
371+
ArrayRef<uint32_t> shape) {
365372
size_t cfgItemCnt = 0;
366373
for (const auto &attr : attrs) {
367374
if (attr.getName() == "KBlock") {
@@ -393,7 +400,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
393400
cfgItemCnt++;
394401
}
395402
}
396-
if (validateConfig(config)) {
403+
if (validateConfig(config, shape)) {
397404
return cfgItemCnt == 9;
398405
} else {
399406
LLVM_DEBUG(llvm::dbgs() << "The predefined config is invalid\n");
@@ -483,7 +490,8 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
483490

484491
// try to read the config from the attributes
485492
SmallVector<NamedAttribute> attrs(linalgOp->getAttrs());
486-
bool hasPredefinedConfig = readConfigFromAttrs(config, attrs);
493+
bool hasPredefinedConfig =
494+
readConfigFromAttrs(config, attrs, SmallVector<uint32_t>{M, N, K});
487495

488496
// if there is a given config, skip the cost model
489497
if (!hasPredefinedConfig) {
@@ -520,7 +528,6 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
520528
hasConfig = true;
521529
}
522530

523-
assert(validateConfig(config) && "config is invalid");
524531
return config;
525532
}
526533
} // namespace gc

0 commit comments

Comments
 (0)