Skip to content

Commit eaebdfc

Browse files
committed
Add Multi Card Command.
1 parent a914ad9 commit eaebdfc

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

examples/benchmark/mmlu/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,27 @@ tar xf data.tar
1616

1717
在当前目录下运行以下脚本:
1818

19+
- 单卡运行
1920
```
21+
export CUDA_VISIBLE_DEVICES=0
2022
python eval.py \
2123
--model_name_or_path /path/to/your/model \
2224
--temperature 0.2 \
2325
--ntrain 5 \
2426
--output_dir ${output_path} \
2527
--dtype 'float16'
2628
```
29+
- 多卡运行
30+
```
31+
export CUDA_VISIBLE_DEVICES=0,1,2,3
32+
python -m paddle.distributed.fleet.launch eval.py \
33+
--model_name_or_path /path/to/your/model \
34+
--temperature 0.2 \
35+
--ntrain 5 \
36+
--output_dir ${output_path} \
37+
--dtype 'float16' \
38+
--tensor_parallel_degree 4
39+
```
2740

2841
参数说明
2942

examples/benchmark/mmlu/eval.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818

1919
import numpy as np
20+
import paddle
2021
import pandas as pd
2122
from categories import categories, subcategories
2223
from evaluator import ModelEvaluator
@@ -29,9 +30,9 @@ def main(args, evaluator):
2930
[f.split("_test.csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test")) if "_test.csv" in f]
3031
)
3132
if not os.path.exists(args.output_dir):
32-
os.makedirs(args.output_dir)
33+
os.makedirs(args.output_dir, exist_ok=True)
3334
if not os.path.exists(os.path.join(args.output_dir, "results_{}".format(args.model_name_or_path))):
34-
os.makedirs(os.path.join(args.output_dir, "results_{}".format(args.model_name_or_path)))
35+
os.makedirs(os.path.join(args.output_dir, "results_{}".format(args.model_name_or_path)), exist_ok=True)
3536

3637
all_cors = []
3738
subcat_cors = {subcat: [] for subcat_lists in subcategories.values() for subcat in subcat_lists}
@@ -95,15 +96,25 @@ def main(args, evaluator):
9596
parser.add_argument("--data_dir", "-d", type=str, default="data")
9697
parser.add_argument("--output_dir", type=str, default="results")
9798
parser.add_argument("--dtype", default="float32", type=str)
99+
parser.add_argument("--tensor_parallel_degree", default=1, type=int)
98100

99101
args = parser.parse_args()
100102
print(args)
101103

104+
if args.tensor_parallel_degree > 1:
105+
strategy = paddle.distributed.fleet.DistributedStrategy()
106+
strategy.hybrid_configs = {
107+
"mp_degree": args.tensor_parallel_degree,
108+
}
109+
# Fix the parameter-initialization seed used under tensor parallelism
110+
strategy.tensor_parallel_configs = {"tensor_init_seed": 1234}
111+
paddle.distributed.fleet.init(is_collective=True, strategy=strategy)
102112
evaluator = ModelEvaluator(
103113
model_name_or_path=args.model_name_or_path,
104114
ntrain=args.ntrain,
105115
temperature=args.temperature,
106116
dtype=args.dtype,
117+
tensor_parallel_degree=args.tensor_parallel_degree,
107118
)
108119

109120
main(args, evaluator=evaluator)

examples/benchmark/mmlu/evaluator.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,21 @@
2323

2424

2525
class ModelEvaluator(object):
26-
def __init__(self, model_name_or_path, ntrain, temperature=0.2, dtype="float32"):
26+
def __init__(self, model_name_or_path, ntrain, temperature=0.2, dtype="float32", tensor_parallel_degree=1):
2727
self.model_name_or_path = model_name_or_path
2828
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
29-
self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, dtype=dtype, low_cpu_mem_usage=True)
29+
self.tensor_parallel_degree = tensor_parallel_degree
30+
if self.tensor_parallel_degree > 1:
31+
self.model = AutoModelForCausalLM.from_pretrained(
32+
model_name_or_path,
33+
dtype=dtype,
34+
low_cpu_mem_usage=True,
35+
tensor_parallel_output=False,
36+
tensor_parallel_degree=self.tensor_parallel_degree,
37+
tensor_parallel_rank=paddle.distributed.get_rank(),
38+
)
39+
else:
40+
self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, dtype=dtype, low_cpu_mem_usage=True)
3041
self.model.eval()
3142
self.generation_config = dict(
3243
temperature=temperature,

0 commit comments

Comments
 (0)