
dimension in GetSubMask #23

@ichenjia


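```python
import tensorflow as tf
from keras import backend as K  # Keras backend, providing K.cumsum

def GetSubMask(s):
    len_s = tf.shape(s)[1]   # sequence length (second dimension of s)
    bs = tf.shape(s)[:1]     # 1-element tensor holding the batch size
    # batch of len_s x len_s identity matrices; cumsum over axis 1 makes each one lower-triangular
    mask = K.cumsum(tf.eye(len_s, batch_shape=bs), 1)
    return mask
```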

If the input has shape (5, 4, 3), wouldn't the `tf.eye` + `cumsum` here produce a lower-triangular tensor of shape (5, 4, 4) instead of (5, 4, 3), because of the `[:1]`?
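To make the shapes concrete, here is a minimal pure-Python sketch that mirrors the two operations in `GetSubMask` for an input of shape (5, 4, 3). It is not the repository's code: the helpers `eye_batched`, `cumsum_axis1`, `shape` and `get_sub_mask_sketch` are hypothetical, and the sketch takes the input's shape tuple rather than a tensor. It only assumes the documented behaviour of `tf.eye(num_rows, batch_shape=...)` (a batch of identity matrices) and a cumulative sum along axis 1.

```python
def eye_batched(n, batch_size):
    """Mimic tf.eye(n, batch_shape=[batch_size]): batch_size copies of the n x n identity."""
    return [[[1 if i == j else 0 for j in range(n)] for i in range(n)]
            for _ in range(batch_size)]


def cumsum_axis1(t):
    """Cumulative sum over axis 1 (the row axis) of a (batch, rows, cols) nested list."""
    out = []
    for mat in t:
        running = [0] * len(mat[0])
        rows = []
        for row in mat:
            # each new row is the element-wise sum of all rows so far
            running = [r + v for r, v in zip(running, row)]
            rows.append(running)
        out.append(rows)
    return out


def shape(x):
    """Return the shape of a regularly nested list, e.g. (5, 4, 4)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def get_sub_mask_sketch(s_shape):
    """Shape-level mirror of GetSubMask for an input whose shape is s_shape."""
    len_s = s_shape[1]       # like tf.shape(s)[1]
    batch_size = s_shape[0]  # like tf.shape(s)[:1], which holds only the batch size
    return cumsum_axis1(eye_batched(len_s, batch_size))


mask = get_sub_mask_sketch((5, 4, 3))
print(shape(mask))  # (5, 4, 4)
print(mask[0])      # [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
```

Both trailing axes of the result come from `len_s`, so the last dimension of `s` (3 here) never enters the mask's shape.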
