@@ -608,7 +608,7 @@ class TransformerEncoderLayer(Layer):
 
     Examples:
 
-        .. code-block:: python
+        .. code-block:: pycon
 
             >>> import paddle
             >>> from paddle.nn import TransformerEncoderLayer
@@ -620,7 +620,7 @@ class TransformerEncoderLayer(Layer):
             >>> encoder_layer = TransformerEncoderLayer(128, 2, 512)
             >>> enc_output = encoder_layer(enc_input, attn_mask)
             >>> print(enc_output.shape)
-            [2, 4, 128]
+            paddle.Size( [2, 4, 128])
     """
 
     activation: Layer
@@ -972,7 +972,7 @@ class TransformerDecoderLayer(Layer):
 
     Examples:
 
-        .. code-block:: python
+        .. code-block:: pycon
 
             >>> import paddle
             >>> from paddle.nn import TransformerDecoderLayer
@@ -986,12 +986,11 @@ class TransformerDecoderLayer(Layer):
             >>> # cross attention mask: [batch_size, n_head, tgt_len, src_len]
             >>> cross_attn_mask = paddle.rand((2, 2, 4, 6))
             >>> decoder_layer = TransformerDecoderLayer(128, 2, 512)
-            >>> output = decoder_layer(dec_input,
-            ...                        enc_output,
-            ...                        self_attn_mask,
-            ...                        cross_attn_mask)
+            >>> output = decoder_layer(
+            ...     dec_input, enc_output, self_attn_mask, cross_attn_mask
+            ... )
             >>> print(output.shape)
-            [2, 4, 128]
+            paddle.Size( [2, 4, 128])
     """
 
     normalize_before: bool
@@ -1498,7 +1497,7 @@ class Transformer(Layer):
 
     Examples:
 
-        .. code-block:: python
+        .. code-block:: pycon
 
             >>> import paddle
             >>> from paddle.nn import Transformer
@@ -1514,13 +1513,15 @@ class Transformer(Layer):
             >>> # memory_mask: [batch_size, n_head, tgt_len, src_len]
             >>> cross_attn_mask = paddle.rand((2, 2, 6, 4))
             >>> transformer = Transformer(128, 2, 4, 4, 512)
-            >>> output = transformer(enc_input,
-            ...                      dec_input,
-            ...                      enc_self_attn_mask,
-            ...                      dec_self_attn_mask,
-            ...                      cross_attn_mask)
+            >>> output = transformer(
+            ...     enc_input,
+            ...     dec_input,
+            ...     enc_self_attn_mask,
+            ...     dec_self_attn_mask,
+            ...     cross_attn_mask,
+            ... )
             >>> print(output.shape)
-            [2, 6, 128]
+            paddle.Size( [2, 6, 128])
     """
 
     encoder: Layer
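
For context, the script below stitches the three updated doctest examples into one plain runnable file. It uses only the calls, layer sizes, and mask shapes already shown in the hunks above; the variable names and comments are mine, so treat it as a sketch of the documented usage rather than anything added by this change.

    import paddle
    from paddle.nn import (
        Transformer,
        TransformerDecoderLayer,
        TransformerEncoderLayer,
    )

    # Shapes follow the docstring examples:
    # batch_size=2, src_len=4, tgt_len=6, d_model=128.
    enc_input = paddle.rand((2, 4, 128))
    dec_input = paddle.rand((2, 6, 128))

    # Attention masks are [batch_size, n_head, query_len, key_len].
    enc_self_attn_mask = paddle.rand((2, 2, 4, 4))
    dec_self_attn_mask = paddle.rand((2, 2, 6, 6))
    cross_attn_mask = paddle.rand((2, 2, 6, 4))

    # A single layer can be called on its own, as in the per-class examples ...
    encoder_layer = TransformerEncoderLayer(128, 2, 512)
    enc_output = encoder_layer(enc_input, enc_self_attn_mask)

    decoder_layer = TransformerDecoderLayer(128, 2, 512)
    dec_output = decoder_layer(
        dec_input, enc_output, dec_self_attn_mask, cross_attn_mask
    )

    # ... or the full Transformer(d_model, nhead, num_encoder_layers,
    # num_decoder_layers, dim_feedforward) stacks 4 encoder and 4 decoder layers.
    transformer = Transformer(128, 2, 4, 4, 512)
    output = transformer(
        enc_input, dec_input, enc_self_attn_mask, dec_self_attn_mask, cross_attn_mask
    )

    # Shape values as documented: [2, 4, 128], [2, 6, 128], [2, 6, 128].
    print(enc_output.shape, dec_output.shape, output.shape)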