@@ -45,7 +45,9 @@ def __init__(
         max_tokens: int = 1024,
         use_kernel: bool = False,
         use_meta_tensor: bool = False,
-        injection_policy=None,
+        test_hybrid_engine: bool = False,
+        save_mp_checkpoint_path: Optional[str] = None,
+        # injection_policy=None,
         ds_inference_kwargs: Optional[Dict[str, Any]] = None,
         **from_pretrained_kwargs,
     ):
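For context, the two new flags are consumed in `postprocess_model` below. A minimal usage sketch, assuming the enclosing class is the predictor this diff modifies (`DeepSpeedPredictor` is a placeholder name; the class itself is not shown in the diff):

```python
# Hypothetical usage of the new constructor flags; the class name is a
# placeholder -- the diff does not show the enclosing class.
predictor = DeepSpeedPredictor(
    max_tokens=1024,
    use_kernel=False,              # kernel injection is separate from the Hybrid Engine path
    test_hybrid_engine=True,       # wrap via deepspeed.initialize() instead of init_inference()
    save_mp_checkpoint_path=None,  # or a directory path to save a tensor-parallel-split checkpoint
)
```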
@@ -60,8 +62,10 @@ def __init__(
         self.max_tokens = max_tokens
         self.use_kernel = use_kernel
         self.use_meta_tensor = use_meta_tensor
+        self.test_hybrid_engine = test_hybrid_engine
+        self.save_mp_checkpoint_path = save_mp_checkpoint_path
         # TODO: Allow conversion from strings (need to do dynamic imports)
-        self.injection_policy = injection_policy
+        # self.injection_policy = injection_policy
         self.ds_inference_kwargs = ds_inference_kwargs

         if self.use_kernel:
@@ -114,6 +118,8 @@ def _generate_checkpoint_json(
                 for entry in Path(repo_root).rglob("*.[bp][it][n]")
                 if entry.is_file()
             ]
+
+            # FIXME: "BLOOM" is hard-coded below; is that correct for all models?
             data = {"type": "BLOOM",
                     "checkpoints": file_list, "version": 1.0}
            json.dump(data, f)
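For reference, the file written here is the checkpoints-JSON index that DeepSpeed consumes when loading meta-tensor checkpoints. A sketch of the resulting contents for a two-shard repo (shard names are illustrative; `checkpoints` lists every file matched by the `*.[bp][it][n]` glob, e.g. `.bin` shards):

```python
# Illustrative contents of the generated checkpoints JSON (shard names made up).
{
    "type": "BLOOM",
    "checkpoints": [
        "/repo_root/pytorch_model-00001-of-00002.bin",
        "/repo_root/pytorch_model-00002-of-00002.bin",
    ],
    "version": 1.0,
}
```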
@@ -170,58 +176,78 @@ def load_model(self, model_id: str) -> "PreTrainedModel":
         return model

     def postprocess_model(self, model: "PreTrainedModel") -> "PreTrainedModel":
-        from transformers import GPTNeoXForCausalLM, LlamaForCausalLM
-
-        injection_policy = self.injection_policy
-        # TODO: remove those later when deepspeed master is updated
-        if injection_policy is None and not self.use_kernel:
-            if isinstance(model, GPTNeoXForCausalLM):
-                from transformers import GPTNeoXLayer
-
-                injection_policy = {
-                    GPTNeoXLayer: ("attention.dense", "mlp.dense_4h_to_h")
-                }
-            elif isinstance(model, LlamaForCausalLM):
-                from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-
-                injection_policy = {
-                    LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")
-                }
-
-        if self.use_bettertransformer:
-            from optimum.bettertransformer import BetterTransformer
-
-            logger.info("Transforming the model with BetterTransformer...")
-            model = BetterTransformer.transform(model)
-
-        ds_kwargs = self.ds_inference_kwargs or {}
-        ds_kwargs = ds_kwargs.copy()
-        ds_kwargs.update(
-            dict(
-                dtype=self.dtype,
-                mp_size=self.world_size,
-                replace_with_kernel_inject=self.use_kernel,
-                injection_policy=injection_policy,
-                max_tokens=self.max_tokens,
-            )
-        )
         if self.use_meta_tensor:
-            ds_kwargs.update(
-                dict(base_dir=self._repo_root, checkpoint=self._checkpoints_json)
-            )
-
-        logger.info(f"deepspeed.init_inference kwargs: {ds_kwargs}")
-        model = deepspeed.init_inference(
-            model,
-            **ds_kwargs,
-        )
+            ds_kwargs = dict(base_dir=self._repo_root, checkpoint=self._checkpoints_json)
+        else:
+            ds_kwargs = dict()
+
+        # Use the DeepSpeed Hybrid Engine for inference
+        if self.test_hybrid_engine:
+            ds_config = {
+                "train_batch_size": 2,
+                "fp16": {"enabled": self.dtype == torch.half},
+                "hybrid_engine": {"enabled": True},
+            }
+            model, *_ = deepspeed.initialize(model=model, config=ds_config)
+            model.eval()
+        # Otherwise, use the DeepSpeed Inference engine
+        else:
+            model = deepspeed.init_inference(
+                model,
+                dtype=self.dtype,
+                mp_size=self.world_size,
+                replace_with_kernel_inject=self.use_kernel,
+                max_tokens=self.max_tokens,
+                save_mp_checkpoint_path=self.save_mp_checkpoint_path,
+                **ds_kwargs,
+            )
+        # from transformers import GPTNeoXForCausalLM, LlamaForCausalLM
+
+        # injection_policy = self.injection_policy
+        # # TODO: remove those later when deepspeed master is updated
+        # if injection_policy is None and not self.use_kernel:
+        #     if isinstance(model, GPTNeoXForCausalLM):
+        #         from transformers import GPTNeoXLayer
+
+        #         injection_policy = {
+        #             GPTNeoXLayer: ("attention.dense", "mlp.dense_4h_to_h")
+        #         }
+        #     elif isinstance(model, LlamaForCausalLM):
+        #         from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+        #         injection_policy = {
+        #             LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")
+        #         }
+
+        # if self.use_bettertransformer:
+        #     from optimum.bettertransformer import BetterTransformer
+
+        #     logger.info("Transforming the model with BetterTransformer...")
+        #     model = BetterTransformer.transform(model)
+
+        # ds_kwargs = self.ds_inference_kwargs or {}
+        # ds_kwargs = ds_kwargs.copy()
+        # ds_kwargs.update(
+        #     dict(
+        #         dtype=self.dtype,
+        #         mp_size=self.world_size,
+        #         replace_with_kernel_inject=self.use_kernel,
+        #         injection_policy=injection_policy,
+        #         max_tokens=self.max_tokens,
+        #     )
+        # )
+        # if self.use_meta_tensor:
+        #     ds_kwargs.update(
+        #         dict(base_dir=self._repo_root, checkpoint=self._checkpoints_json)
+        #     )
+
+        # logger.info(f"deepspeed.init_inference kwargs: {ds_kwargs}")
+        # model = deepspeed.init_inference(
+        #     model,
+        #     **ds_kwargs,
+        # )

         if self.torch_compile and self.torch_compile["backend"]:
             logger.info("Compiling the model with torch.compile()...")
             model = torch.compile(model, **self.torch_compile)

         # Add attributes for compatibility with the pipeline
-        model.use_kernel = self.use_kernel
-        model.device = self.device
-        model = model.to(self.device)
+        # model.use_kernel = self.use_kernel
+        # model.device = self.device
+        # model = model.to(self.device)
         return model
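To make the two new initialization paths concrete, a minimal self-contained sketch with a toy module (the Hybrid Engine actually targets supported Hugging Face architectures, and `deepspeed.initialize` expects a distributed environment, e.g. launched via `deepspeed script.py`, so treat this as an API-shape sketch only):

```python
import torch
import torch.nn as nn
import deepspeed

model = nn.Linear(16, 16).half().cuda()

# Hybrid Engine path: wrap with the training-style initializer, then eval().
# The config mirrors the ds_config built in this diff.
ds_config = {
    "train_batch_size": 2,
    "fp16": {"enabled": True},
    "hybrid_engine": {"enabled": True},
}
engine, *_ = deepspeed.initialize(model=model, config=ds_config)
engine.eval()

# Inference-engine path (the else branch above): kernel injection plus an
# optional save_mp_checkpoint_path, which writes a pre-sharded tensor-parallel
# checkpoint that later runs can load directly instead of re-splitting.
# engine = deepspeed.init_inference(
#     model,
#     dtype=torch.half,
#     mp_size=1,
#     replace_with_kernel_inject=True,
#     max_tokens=1024,
#     save_mp_checkpoint_path=None,  # e.g. "/tmp/sharded-ckpt" to save
# )
```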