@@ -334,6 +334,7 @@ def extract_fields_from_class(
334334 arg_type : type [Arg ],
335335 out_type : type [Out ],
336336 auto_attribs : bool ,
337+ skip_fields : ty .Iterable [str ] = (),
337338) -> tuple [dict [str , Arg ], dict [str , Out ]]:
338339 """Extract the input and output fields from an existing class
339340
@@ -348,6 +349,8 @@ def extract_fields_from_class(
348349 auto_attribs : bool
349350 Whether to assume that all attribute annotations should be interpreted as
350351 fields or not
352+ skip_fields : Iterable[str], optional
353+ The names of attributes to skip when extracting the fields, by default ()
351354
352355 Returns
353356 -------
@@ -364,10 +367,15 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]:
364367 fields_dict = {}
365368 # Get fields defined in base classes if present
366369 for field in list_fields (klass ):
367- fields_dict [field .name ] = field
370+ if field .name not in skip_fields :
371+ fields_dict [field .name ] = field
368372 type_hints = ty .get_type_hints (klass )
369373 for atr_name in dir (klass ):
370- if atr_name in ["Task" , "Outputs" ] or atr_name .startswith ("__" ):
374+ if (
375+ atr_name in ["Task" , "Outputs" ]
376+ or atr_name in skip_fields
377+ or atr_name .startswith ("__" )
378+ ):
371379 continue
372380 try :
373381 atr = getattr (klass , atr_name )
@@ -394,7 +402,7 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]:
394402 )
395403 if auto_attribs :
396404 for atr_name , type_ in type_hints .items ():
397- if atr_name .startswith ("_" ):
405+ if atr_name .startswith ("_" ) or atr_name in skip_fields :
398406 continue
399407 if atr_name not in list (fields_dict ) + ["Task" , "Outputs" ]:
400408 fields_dict [atr_name ] = field_type (
0 commit comments