versions of the tensor across different microbatches
(see ``StepOutput`` entry for more information).

The arguments to an ``smp.step``-decorated function should be tensors, or
instances of ``list``, ``tuple``, ``dict``, or ``set``, so that they can be
split across microbatches. If your object does not fall into one of these
categories, you can make the library split it by implementing an
``smp_slice`` method.

Below is an example of how to do this with PyTorch.

.. code:: python

   class CustomType:
       def __init__(self, tensor):
           self.data = tensor

       # The library will call this to invoke slicing on the object,
       # passing in the total number of microbatches (num_mb) and the
       # current microbatch index (mb).
       def smp_slice(self, num_mb, mb, axis):
           dim_size = list(self.data.size())[axis]

           split_size = dim_size // num_mb
           sliced_tensor = self.data.narrow(axis, mb * split_size, split_size)
           return CustomType(sliced_tensor)

   custom_obj = CustomType(torch.ones(4,))

   @smp.step()
   def step(custom_obj):
       loss = model(custom_obj)
       model.backward(loss)
       return loss
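
To make the slicing behavior concrete, the following minimal sketch calls
``smp_slice`` by hand, the way the library does once per microbatch. It
reuses the ``CustomType`` class from the example above; the microbatch count
of 2 and the slicing axis of 0 are assumptions chosen purely for
illustration, and running it does not require an ``smp`` runtime.

.. code:: python

   import torch

   # Illustration only: slice a batch of 4 into 2 microbatches along axis 0
   # by calling smp_slice with each microbatch index, as the library would.
   # num_mb=2 and axis=0 are assumed values for this sketch.
   obj = CustomType(torch.arange(4.0))
   microbatches = [obj.smp_slice(num_mb=2, mb=i, axis=0) for i in range(2)]
   print([m.data for m in microbatches])
   # [tensor([0., 1.]), tensor([2., 3.])]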

**Important:** ``smp.step`` splits the batch into microbatches, and
executes everything inside the decorated function once per microbatch.
This might affect the behavior of batch normalization, any operation