Pyright is stubborn about using union types in PyTorch #1409
Pyright produces a type error on line 15 of the following code, even though the code runs correctly:

```python
import torch.nn as nn
import torch.nn.functional as F
from typing import List


class iter_error(nn.Module):
    def __init__(self, layer_sizes : List[int]):
        super().__init__()
        self.layers : nn.ModuleList = nn.ModuleList() # explicit annotation
        self.layer_sizes = layer_sizes
        for h_in, h_out in zip(self.layer_sizes[:-1], self.layer_sizes[1:]):
            self.layers.append(nn.Linear(h_in, h_out))

    def forward(self, x):
        for i, layer in enumerate(self.layers[:-1]): # type error is signaled inside this enumerate
            x = F.relu(layer(x))
        x = self.layers[-1](x)
        return x
```

The error:

It looks like Pyright considers …
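The failure can be reproduced without installing PyTorch. The sketch below is an illustration under stated assumptions: `ToyModule` and `ToyModuleList` are hypothetical stand-ins, not PyTorch classes, and `ToyModuleList.__getitem__` copies the single `Union`-returning signature quoted from PyTorch in the reply. Writing `for layer in layers[:-1]` against this class should give the same kind of Pyright error, because one arm of the union (`ToyModule`) is not iterable. The two caller-side workarounds shown, `typing.cast` and slicing a plain `list`, are suggestions of this sketch, not something proposed in the thread.

```python
from typing import Iterator, List, Union, cast


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: adds a fixed offset."""

    def __init__(self, offset: float) -> None:
        self.offset = offset

    def __call__(self, x: float) -> float:
        return x + self.offset


class ToyModuleList(ToyModule):
    """Hypothetical stand-in for torch.nn.ModuleList with the un-overloaded signature."""

    def __init__(self, modules: List[ToyModule]) -> None:
        super().__init__(0.0)
        self._modules = list(modules)

    # Single signature: ints and slices share one Union return type, so a
    # type checker cannot tell that a slice always yields a ToyModuleList.
    def __getitem__(self, idx: Union[int, slice]) -> Union[ToyModule, "ToyModuleList"]:
        if isinstance(idx, slice):
            return ToyModuleList(self._modules[idx])
        return self._modules[idx]

    def __iter__(self) -> Iterator[ToyModule]:
        return iter(self._modules)

    def __len__(self) -> int:
        return len(self._modules)


def forward_with_cast(layers: ToyModuleList, x: float) -> float:
    # Workaround 1 (assumption): assert to the checker what the slice returns.
    hidden = cast(ToyModuleList, layers[:-1])
    for layer in hidden:
        x = max(0.0, layer(x))  # stand-in for F.relu
    return layers[-1](x)


def forward_with_list(layers: ToyModuleList, x: float) -> float:
    # Workaround 2 (assumption): __iter__ is typed Iterator[ToyModule], so
    # list(...) is List[ToyModule] and its slices stay List[ToyModule].
    modules = list(layers)
    for layer in modules[:-1]:
        x = max(0.0, layer(x))  # stand-in for F.relu
    return modules[-1](x)


layers = ToyModuleList([ToyModule(1.0), ToyModule(2.0), ToyModule(-10.0)])
print(forward_with_cast(layers, 0.0))  # -7.0
print(forward_with_list(layers, 0.0))  # -7.0
```

Both workarounds only change what the checker sees; the runtime behavior is identical to the unannotated loop.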
Replies: 1 comment
The expression `self.layers` is evaluated as type `ModuleList`, but the expression `self.layers[:-1]` has the type `ModuleList | Module`. This comes directly from the `ModuleList.__getitem__` method type declaration:

```python
class ModuleList(Module):
    ...
    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, 'ModuleList']: ...
```

The problem appears to be that `__getitem__` is missing an `@overload` declaration. If you add the following overloads to the PyTorch sources (keeping the existing method as the implementation), it will work as expected:

```python
@overload
def __getitem__(self, idx: int) -> Module: ...
@overload
def __getitem__(self, idx: slice) -> 'ModuleList': ...
```

You may want to file a bug report or submit a PR in the PyTorch project.
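The overload fix suggested in the reply can be tried on a self-contained example. In this minimal sketch, `OverloadedModule` and `OverloadedModuleList` are hypothetical stand-ins, not PyTorch classes. The two `@overload` signatures are what a checker such as Pyright uses at call sites, and the single undecorated method is what actually runs. With these overloads, `layers[:-1]` should be inferred as `OverloadedModuleList` and `layers[-1]` as `OverloadedModule`, so a loop shaped like the question's `forward()` should type-check without a cast.

```python
from typing import Iterator, List, Union, overload


class OverloadedModule:
    """Hypothetical stand-in for torch.nn.Module: doubles its input."""

    def __call__(self, x: float) -> float:
        return 2.0 * x


class OverloadedModuleList(OverloadedModule):
    """Hypothetical stand-in for torch.nn.ModuleList with an overloaded __getitem__."""

    def __init__(self, modules: List[OverloadedModule]) -> None:
        self._modules = list(modules)

    # Checker-only signature: an int index yields a single module.
    @overload
    def __getitem__(self, idx: int) -> OverloadedModule: ...

    # Checker-only signature: a slice yields another list.
    @overload
    def __getitem__(self, idx: slice) -> "OverloadedModuleList": ...

    # Runtime implementation; callers are checked against the overloads above.
    def __getitem__(
        self, idx: Union[int, slice]
    ) -> Union[OverloadedModule, "OverloadedModuleList"]:
        if isinstance(idx, slice):
            return OverloadedModuleList(self._modules[idx])
        return self._modules[idx]

    def __iter__(self) -> Iterator[OverloadedModule]:
        return iter(self._modules)

    def __len__(self) -> int:
        return len(self._modules)


def forward(layers: OverloadedModuleList, x: float) -> float:
    # Same shape as the question's forward(); no cast is needed here.
    for _, layer in enumerate(layers[:-1]):
        x = max(0.0, layer(x))  # stand-in for F.relu
    return layers[-1](x)


layers = OverloadedModuleList([OverloadedModule(), OverloadedModule(), OverloadedModule()])
print(forward(layers, 1.5))  # 12.0
```

The runtime implementation is unchanged from the `Union` version; only the declared signatures differ, which is why this is a typing-only change on the library side.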