@@ -91,7 +91,7 @@ class ExportArchive(BackendExportArchive):
91
91
92
92
**Note on resource tracking:**
93
93
94
- `ExportArchive` is able to automatically track all `tf .Variables` used
94
+ `ExportArchive` is able to automatically track all `keras .Variables` used
95
95
by its endpoints, so most of the time calling `.track(model)`
96
96
is not strictly required. However, if your model uses lookup layers such
97
97
as `IntegerLookup`, `StringLookup`, or `TextVectorization`,
@@ -104,9 +104,10 @@ class ExportArchive(BackendExportArchive):
104
104
105
105
def __init__ (self ):
106
106
super ().__init__ ()
107
- if backend .backend () not in ("tensorflow" , "jax" ):
107
+ if backend .backend () not in ("tensorflow" , "jax" , "torch" ):
108
108
raise NotImplementedError (
109
- "The export API is only compatible with JAX and TF backends."
109
+ "`ExportArchive` is only compatible with TensorFlow, JAX and "
110
+ "Torch backends."
110
111
)
111
112
112
113
self ._endpoint_names = []
@@ -141,8 +142,8 @@ def track(self, resource):
141
142
(`TextVectorization`, `IntegerLookup`, `StringLookup`)
142
143
are automatically tracked in `add_endpoint()`.
143
144
144
- Arguments :
145
- resource: A trackable TensorFlow resource.
145
+ Args :
146
+ resource: A trackable Keras resource, such as a layer or model .
146
147
"""
147
148
if isinstance (resource , layers .Layer ) and not resource .built :
148
149
raise ValueError (
@@ -334,12 +335,78 @@ def serving_fn(x):
334
335
self ._endpoint_names .append (name )
335
336
return decorated_fn
336
337
338
+ def track_and_add_endpoint (self , name , resource , input_signature , ** kwargs ):
339
+ """Track the variables and register a new serving endpoint.
340
+
341
+ This function combines the functionality of `track` and `add_endpoint`.
342
+ It tracks the variables of the `resource` (either a layer or a model)
343
+ and registers a serving endpoint using `resource.__call__`.
344
+
345
+ Args:
346
+ name: `str`. The name of the endpoint.
347
+ resource: A trackable Keras resource, such as a layer or model.
348
+ input_signature: Optional. Specifies the shape and dtype of `fn`.
349
+ Can be a structure of `keras.InputSpec`, `tf.TensorSpec`,
350
+ `backend.KerasTensor`, or backend tensor (see below for an
351
+ example showing a `Functional` model with 2 input arguments). If
352
+ not provided, `fn` must be a `tf.function` that has been called
353
+ at least once. Defaults to `None`.
354
+ **kwargs: Additional keyword arguments:
355
+ - Specific to the JAX backend:
356
+ - `is_static`: Optional `bool`. Indicates whether `fn` is
357
+ static. Set to `False` if `fn` involves state updates
358
+ (e.g., RNG seeds).
359
+ - `jax2tf_kwargs`: Optional `dict`. Arguments for
360
+ `jax2tf.convert`. See [`jax2tf.convert`](
361
+ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
362
+ If `native_serialization` and `polymorphic_shapes` are
363
+ not provided, they are automatically computed.
364
+
365
+ """
366
+ if name in self ._endpoint_names :
367
+ raise ValueError (f"Endpoint name '{ name } ' is already taken." )
368
+ if not isinstance (resource , layers .Layer ):
369
+ raise ValueError (
370
+ "Invalid resource type. Expected an instance of a Keras "
371
+ "`Layer` or `Model`. "
372
+ f"Received: resource={ resource } (of type { type (resource )} )"
373
+ )
374
+ if not resource .built :
375
+ raise ValueError (
376
+ "The layer provided has not yet been built. "
377
+ "It must be built before export."
378
+ )
379
+ if backend .backend () != "jax" :
380
+ if "jax2tf_kwargs" in kwargs or "is_static" in kwargs :
381
+ raise ValueError (
382
+ "'jax2tf_kwargs' and 'is_static' are only supported with "
383
+ f"the jax backend. Current backend: { backend .backend ()} "
384
+ )
385
+
386
+ input_signature = tree .map_structure (_make_tensor_spec , input_signature )
387
+
388
+ if not hasattr (BackendExportArchive , "track_and_add_endpoint" ):
389
+ # Default behavior.
390
+ self .track (resource )
391
+ return self .add_endpoint (
392
+ name , resource .__call__ , input_signature , ** kwargs
393
+ )
394
+ else :
395
+ # Special case for the torch backend.
396
+ decorated_fn = BackendExportArchive .track_and_add_endpoint (
397
+ self , name , resource , input_signature , ** kwargs
398
+ )
399
+ self ._endpoint_signatures [name ] = input_signature
400
+ setattr (self ._tf_trackable , name , decorated_fn )
401
+ self ._endpoint_names .append (name )
402
+ return decorated_fn
403
+
337
404
def add_variable_collection (self , name , variables ):
338
405
"""Register a set of variables to be retrieved after reloading.
339
406
340
407
Arguments:
341
408
name: The string name for the collection.
342
- variables: A tuple/list/set of `tf .Variable` instances.
409
+ variables: A tuple/list/set of `keras .Variable` instances.
343
410
344
411
Example:
345
412
@@ -496,9 +563,6 @@ def export_saved_model(
496
563
):
497
564
"""Export the model as a TensorFlow SavedModel artifact for inference.
498
565
499
- **Note:** This feature is currently supported only with TensorFlow and
500
- JAX backends.
501
-
502
566
This method lets you export a model to a lightweight SavedModel artifact
503
567
that contains the model's forward pass only (its `call()` method)
504
568
and can be served via e.g. TensorFlow Serving. The forward pass is
@@ -527,6 +591,14 @@ def export_saved_model(
527
591
If `native_serialization` and `polymorphic_shapes` are not
528
592
provided, they are automatically computed.
529
593
594
+ **Note:** This feature is currently supported only with TensorFlow, JAX and
595
+ Torch backends. Support for the Torch backend is experimental.
596
+
597
+ **Note:** The dynamic shape feature is not yet supported with Torch
598
+ backend. As a result, you must fully define the shapes of the inputs using
599
+ `input_signature`. If `input_signature` is not provided, all instances of
600
+ `None` (such as the batch size) will be replaced with `1`.
601
+
530
602
Example:
531
603
532
604
```python
@@ -543,28 +615,29 @@ def export_saved_model(
543
615
`export()` method relies on `ExportArchive` internally.
544
616
"""
545
617
export_archive = ExportArchive ()
546
- export_archive .track (model )
547
- if isinstance (model , (Functional , Sequential )):
548
- if input_signature is None :
618
+ if input_signature is None :
619
+ if not model .built :
620
+ raise ValueError (
621
+ "The layer provided has not yet been built. "
622
+ "It must be built before export."
623
+ )
624
+ if isinstance (model , (Functional , Sequential )):
549
625
input_signature = tree .map_structure (
550
626
_make_tensor_spec , model .inputs
551
627
)
552
- if isinstance (input_signature , list ) and len (input_signature ) > 1 :
553
- input_signature = [input_signature ]
554
- export_archive .add_endpoint (
555
- "serve" , model .__call__ , input_signature , ** kwargs
556
- )
557
- else :
558
- if input_signature is None :
628
+ if isinstance (input_signature , list ) and len (input_signature ) > 1 :
629
+ input_signature = [input_signature ]
630
+ else :
559
631
input_signature = _get_input_signature (model )
560
- if not input_signature or not model ._called :
561
- raise ValueError (
562
- "The model provided has never called. "
563
- "It must be called at least once before export."
564
- )
565
- export_archive .add_endpoint (
566
- "serve" , model .__call__ , input_signature , ** kwargs
567
- )
632
+ if not input_signature or not model ._called :
633
+ raise ValueError (
634
+ "The model provided has never called. "
635
+ "It must be called at least once before export."
636
+ )
637
+
638
+ export_archive .track_and_add_endpoint (
639
+ "serve" , model , input_signature , ** kwargs
640
+ )
568
641
export_archive .write_out (filepath , verbose = verbose )
569
642
570
643
0 commit comments