Skip to content

Commit e4b7522

Browse files
add document of python_infer with depoly module (#885)
1 parent 2a1e85a commit e4b7522

File tree

4 files changed

+31
-5
lines changed

4 files changed

+31
-5
lines changed

deploy/python_infer/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from deploy.python_infer.base import Predictor
16+
from deploy.python_infer.pinn_predictor import PINNPredictor
17+
18+
# alias as PINNPredictor can be used in most cases
19+
GeneralPredictor = PINNPredictor
20+
21+
__all__ = [
22+
"Predictor",
23+
"PINNPredictor",
24+
"GeneralPredictor",
25+
]

docs/zh/api/depoly/python_infer.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Python_infer(Python 推理) 模块
2+
3+
::: deploy.python_infer
4+
handler: python
5+
options:
6+
members:
7+
- Predictor
8+
- GeneralPredictor
9+
- PINNPredictor
10+
show_root_heading: true
11+
heading_level: 3

mkdocs.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ nav:
8686
- NowcastNet: zh/examples/nowcastnet.md
8787
- DGMR: zh/examples/dgmr.md
8888
- API文档:
89-
- " ":
89+
- ppsci:
9090
- ppsci.arch: zh/api/arch.md
9191
- ppsci.autodiff: zh/api/autodiff.md
9292
- ppsci.constraint: zh/api/constraint.md
@@ -119,6 +119,8 @@ nav:
119119
- ppsci.visualize: zh/api/visualize.md
120120
- ppsci.experimental: zh/api/experimental.md
121121
- ppsci.probability: zh/api/probability.md
122+
- deploy:
123+
- deploy.python_infer: zh/api/depoly/python_infer.md
122124
- 使用指南: zh/user_guide.md
123125
- 开发与复现指南:
124126
- 开发指南: zh/development.md

ppsci/arch/cuboid_transformer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -912,13 +912,14 @@ def get_initial_z(self, final_mem, T_out):
912912
raise NotImplementedError
913913
return initial_z
914914

915-
def forward(self, x, verbose=False):
915+
def forward(self, x: "paddle.Tensor", verbose: bool = False) -> "paddle.Tensor":
916916
"""
917917
Args:
918-
x: Shape (B, T, H, W, C)
919-
verbose: if True, print intermediate shapes
918+
x (paddle.Tensor): Tensor with shape (B, T, H, W, C).
919+
verbose (bool): if True, print intermediate shapes.
920+
920921
Returns:
921-
out: The output Shape (B, T_out, H, W, C_out)
922+
out (paddle.Tensor): The output Shape (B, T_out, H, W, C_out)
922923
"""
923924

924925
x = self.concat_to_tensor(x, self.input_keys)

0 commit comments

Comments
 (0)