Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 66 additions & 5 deletions iris/iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,73 @@ def ones(self, *size, out=None, dtype=None, layout=torch.strided, device=None, r
tensor.requires_grad_()
return tensor.reshape(size)

def full(self, size, fill_value, dtype=torch.int):
self.debug(f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}")
def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
"""
Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value.
The tensor is allocated on the Iris symmetric heap.

Args:
size (int...): a list, tuple, or torch.Size of integers defining the shape of the output tensor.
fill_value (Scalar): the value to fill the output tensor with.

Keyword Arguments:
out (Tensor, optional): the output tensor.
dtype (torch.dtype, optional): the desired data type of returned tensor.
Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional): the desired layout of returned Tensor.
Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter.
device (torch.device, optional): the desired device of returned tensor.
Default: if None, uses the current device for the default tensor type.
requires_grad (bool, optional): If autograd should record operations on the returned tensor.
Default: False.
"""
self.debug(
f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}"
)

# Infer dtype from fill_value if not provided
if dtype is None:
if isinstance(fill_value, (int, float)):
if isinstance(fill_value, float):
dtype = torch.get_default_dtype()
else:
dtype = torch.int64
else:
# For other types (like tensors), use their dtype
dtype = torch.get_default_dtype()

# Use current device if none specified
if device is None:
device = self.device

# Validate device compatibility with Iris
self.__throw_if_invalid_device(device)

# Parse size and calculate number of elements
size, num_elements = self.parse_size(size)
tensor = self.allocate(num_elements=num_elements, dtype=dtype)
tensor.fill_(fill_value)
return tensor.reshape(size)

# If out is provided, use it; otherwise allocate new tensor
if out is not None:
self.__throw_if_invalid_output_tensor(out, num_elements, dtype)
# Fill with the specified value
out.fill_(fill_value)
# Create a reshaped view of the out tensor
tensor = out.view(size)
else:
tensor = self.allocate(num_elements=num_elements, dtype=dtype)
# Fill with the specified value
tensor.fill_(fill_value)
# Reshape to the desired size
tensor = tensor.reshape(size)

# Apply the requested layout
tensor = self.__apply_layout(tensor, layout)

# Set requires_grad if specified
if requires_grad:
tensor.requires_grad_()

return tensor

def uniform(self, size, low=0.0, high=1.0, dtype=torch.float):
self.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}")
Expand Down
Loading