@@ -384,69 +384,36 @@ def _create_speculation_config(neuron_config: NxDNeuronConfig) -> NxDNeuronConfig:
         return spec_neuron_config

     @staticmethod
-    def _create_context_encoding_builder(model_cls, config, neuron_config):
+    def create_graph_builders(model_cls, config, neuron_config):
+        graph_builders = {}
         ctx_neuron_config = NxDModelForCausalLM._create_context_encoding_config(neuron_config)
-
-        return NxDDecoderBuilder(
+        graph_builders["context_encoding"] = NxDDecoderBuilder(
             config=config,
             neuron_config=ctx_neuron_config,
             max_tokens=ctx_neuron_config.max_context_length,
             active_tokens=ctx_neuron_config.max_context_length,
             model_cls=model_cls,
-            tag=CONTEXT_ENCODING_MODEL_TAG,
         )
-
-    @staticmethod
-    def _create_token_generation_builder(model_cls, config, neuron_config, enable_wlt_optimization: bool = True):
         tkg_neuron_config = NxDModelForCausalLM._create_token_generation_config(neuron_config)
-
-        return NxDDecoderBuilder(
+        graph_builders["token_generation"] = NxDDecoderBuilder(
             config=config,
             neuron_config=tkg_neuron_config,
             max_tokens=tkg_neuron_config.sequence_length,
             active_tokens=1,
             model_cls=model_cls,
-            tag=TOKEN_GENERATION_MODEL_TAG,
-            priority_model_idx=0 if enable_wlt_optimization else None,  # to turn on weight layout optimization
-        )
-
-    @staticmethod
-    def _create_speculation_builder(model_cls, config, neuron_config):
-        spec_neuron_config = NxDModelForCausalLM._create_speculation_config(neuron_config)
-
-        return NxDDecoderBuilder(
-            config=config,
-            neuron_config=spec_neuron_config,
-            max_tokens=spec_neuron_config.sequence_length,
-            active_tokens=spec_neuron_config.speculation_length,
-            model_cls=model_cls,
-            tag=SPECULATION_MODEL_TAG,
             priority_model_idx=0,  # to turn on weight layout optimization
         )
-
-    @staticmethod
-    def create_model_builders(model_cls, config, neuron_config):
-        model_builders = [
-            NxDModelForCausalLM._create_context_encoding_builder(
-                model_cls,
-                config,
-                neuron_config,
-            ),
-            NxDModelForCausalLM._create_token_generation_builder(
-                model_cls,
-                config,
-                neuron_config,
-            ),
-        ]
         if neuron_config.speculation_length > 0:
-            model_builders.append(
-                NxDModelForCausalLM._create_speculation_builder(
-                    model_cls,
-                    config,
-                    neuron_config,
-                )
+            spec_neuron_config = NxDModelForCausalLM._create_speculation_config(neuron_config)
+            graph_builders["speculation_model"] = NxDDecoderBuilder(
+                config=config,
+                neuron_config=spec_neuron_config,
+                max_tokens=spec_neuron_config.sequence_length,
+                active_tokens=spec_neuron_config.speculation_length,
+                model_cls=model_cls,
+                priority_model_idx=0,  # to turn on weight layout optimization
+            )
-        return model_builders
+        return graph_builders

     def forward(
         self,
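For reference, a minimal sketch of how a caller might consume the dict returned by the new create_graph_builders. The keys and the conditional speculation entry come from the hunk above; the model class name and the surrounding setup are hypothetical, not part of this change.

# Minimal caller sketch, assuming `config` and `neuron_config` are already built.
# `MyDecoderModel` is a hypothetical class standing in for cls._model_cls.
graph_builders = NxDModelForCausalLM.create_graph_builders(
    model_cls=MyDecoderModel,
    config=config,
    neuron_config=neuron_config,
)

# Prefill and decode graphs are always present; the speculation graph is only
# built when speculative decoding is enabled.
assert "context_encoding" in graph_builders
assert "token_generation" in graph_builders
if neuron_config.speculation_length > 0:
    assert "speculation_model" in graph_builders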
@@ -650,7 +617,7 @@ def _from_pretrained(
             traced_model = torch.jit.load(os.path.join(tmpdir, cls.COMPILED_MODEL_FILE_NAME))
         else:
             traced_model = torch.jit.load(os.path.join(model_id, cls.COMPILED_MODEL_FILE_NAME))
-        model_builders = NxDModelForCausalLM.create_model_builders(
+        model_builders = NxDModelForCausalLM.create_graph_builders(
             cls._model_cls, config=config, neuron_config=neuron_config
         )
         model = cls(
@@ -707,7 +674,7 @@ def _export(
         # Evaluate head_dim if it is defined but set to null (like in Mixtral for transformers 4.54+)
         if hasattr(config, "head_dim") and config.head_dim is None:
             config.head_dim = config.hidden_size // config.num_attention_heads
-        model_builders = cls.create_model_builders(
+        model_builders = cls.create_graph_builders(
             model_cls=cls._model_cls,
             config=config,
             neuron_config=neuron_config,
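Since both call sites above keep the local name model_builders, external code that indexed the old list only needs a key lookup after this change. A hedged before/after sketch; the positional index reflects the list order visible in the removed create_model_builders and is an assumption about how prior callers used it.

# Before: create_model_builders returned a list, with the token generation
# builder at index 1 (after the context encoding builder).
# builders = NxDModelForCausalLM.create_model_builders(model_cls, config, neuron_config)
# tkg_builder = builders[1]

# After: create_graph_builders returns a dict keyed by graph name.
builders = NxDModelForCausalLM.create_graph_builders(model_cls, config, neuron_config)
tkg_builder = builders["token_generation"]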