@@ -383,24 +383,26 @@ def _create_speculation_config(neuron_config: NxDNeuronConfig) -> NxDNeuronConfi
383383        spec_neuron_config .batch_size  =  neuron_config .tkg_batch_size 
384384        return  spec_neuron_config 
385385
386-     @staticmethod  
387-     def  create_graph_builders (model_cls , config , neuron_config ):
386+     @classmethod  
387+     def  create_graph_builders (cls , config , neuron_config ):
388+         if  cls ._model_cls  is  None :
389+             raise  SystemError (f"No underlying model class defined for { cls }  ." )
388390        graph_builders  =  {}
389391        ctx_neuron_config  =  NxDModelForCausalLM ._create_context_encoding_config (neuron_config )
390392        graph_builders ["context_encoding" ] =  NxDDecoderBuilder (
391393            config = config ,
392394            neuron_config = ctx_neuron_config ,
393395            max_tokens = ctx_neuron_config .max_context_length ,
394396            active_tokens = ctx_neuron_config .max_context_length ,
395-             model_cls = model_cls ,
397+             model_cls = cls . _model_cls ,
396398        )
397399        tkg_neuron_config  =  NxDModelForCausalLM ._create_token_generation_config (neuron_config )
398400        graph_builders ["token_generation" ] =  NxDDecoderBuilder (
399401            config = config ,
400402            neuron_config = tkg_neuron_config ,
401403            max_tokens = tkg_neuron_config .sequence_length ,
402404            active_tokens = 1 ,
403-             model_cls = model_cls ,
405+             model_cls = cls . _model_cls ,
404406            priority_model_idx = 0 ,  # to turn on weight layout optimization 
405407        )
406408        if  neuron_config .speculation_length  >  0 :
@@ -410,7 +412,7 @@ def create_graph_builders(model_cls, config, neuron_config):
410412                neuron_config = spec_neuron_config ,
411413                max_tokens = spec_neuron_config .sequence_length ,
412414                active_tokens = spec_neuron_config .speculation_length ,
413-                 model_cls = model_cls ,
415+                 model_cls = cls . _model_cls ,
414416                priority_model_idx = 0 ,  # to turn on weight layout optimization 
415417            )
416418        return  graph_builders 
@@ -617,9 +619,7 @@ def _from_pretrained(
617619                traced_model  =  torch .jit .load (os .path .join (tmpdir , cls .COMPILED_MODEL_FILE_NAME ))
618620        else :
619621            traced_model  =  torch .jit .load (os .path .join (model_id , cls .COMPILED_MODEL_FILE_NAME ))
620-         graph_builders  =  NxDModelForCausalLM .create_graph_builders (
621-             cls ._model_cls , config = config , neuron_config = neuron_config 
622-         )
622+         graph_builders  =  NxDModelForCausalLM .create_graph_builders (config = config , neuron_config = neuron_config )
623623        model  =  cls (
624624            config = config ,
625625            neuron_config = neuron_config ,
@@ -647,7 +647,7 @@ def _export(
647647        force_download : bool  |  None  =  False ,
648648        local_files_only : bool  |  None  =  False ,
649649        trust_remote_code : bool  |  None  =  False ,
650-         load_weights : bool  =  False ,
650+         load_weights : bool  |   None   =  False ,
651651        ** kwargs ,
652652    ) ->  "NeuronModelForCausalLM" :
653653        if  len (kwargs ) >  0 :
@@ -675,7 +675,6 @@ def _export(
675675        if  hasattr (config , "head_dim" ) and  config .head_dim  is  None :
676676            config .head_dim  =  config .hidden_size  //  config .num_attention_heads 
677677        graph_builders  =  cls .create_graph_builders (
678-             model_cls = cls ._model_cls ,
679678            config = config ,
680679            neuron_config = neuron_config ,
681680        )
0 commit comments