Skip to content

Commit 4089f88

Browse files
authored
Match full from PyTorch (#137)
1 parent 103b9e2 commit 4089f88

File tree

2 files changed

+522
-5
lines changed

2 files changed

+522
-5
lines changed

iris/iris.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -418,12 +418,73 @@ def ones(self, *size, out=None, dtype=None, layout=torch.strided, device=None, r
418418

419419
return tensor
420420

421-
def full(self, size, fill_value, dtype=torch.int):
422-
self.debug(f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}")
421+
def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
422+
"""
423+
Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value.
424+
The tensor is allocated on the Iris symmetric heap.
425+
426+
Args:
427+
size (int...): a list, tuple, or torch.Size of integers defining the shape of the output tensor.
428+
fill_value (Scalar): the value to fill the output tensor with.
429+
430+
Keyword Arguments:
431+
out (Tensor, optional): the output tensor.
432+
dtype (torch.dtype, optional): the desired data type of returned tensor.
433+
Default: if None, uses a global default (see torch.set_default_dtype()).
434+
layout (torch.layout, optional): the desired layout of returned Tensor.
435+
Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter.
436+
device (torch.device, optional): the desired device of returned tensor.
437+
Default: if None, uses the current device for the default tensor type.
438+
requires_grad (bool, optional): If autograd should record operations on the returned tensor.
439+
Default: False.
440+
"""
441+
self.debug(
442+
f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}"
443+
)
444+
445+
# Infer dtype from fill_value if not provided
446+
if dtype is None:
447+
if isinstance(fill_value, (int, float)):
448+
if isinstance(fill_value, float):
449+
dtype = torch.get_default_dtype()
450+
else:
451+
dtype = torch.int64
452+
else:
453+
# For other types (like tensors), use their dtype
454+
dtype = torch.get_default_dtype()
455+
456+
# Use current device if none specified
457+
if device is None:
458+
device = self.device
459+
460+
# Validate device compatibility with Iris
461+
self.__throw_if_invalid_device(device)
462+
463+
# Parse size and calculate number of elements
423464
size, num_elements = self.parse_size(size)
424-
tensor = self.allocate(num_elements=num_elements, dtype=dtype)
425-
tensor.fill_(fill_value)
426-
return tensor.reshape(size)
465+
466+
# If out is provided, use it; otherwise allocate new tensor
467+
if out is not None:
468+
self.__throw_if_invalid_output_tensor(out, num_elements, dtype)
469+
# Fill with the specified value
470+
out.fill_(fill_value)
471+
# Create a reshaped view of the out tensor
472+
tensor = out.view(size)
473+
else:
474+
tensor = self.allocate(num_elements=num_elements, dtype=dtype)
475+
# Fill with the specified value
476+
tensor.fill_(fill_value)
477+
# Reshape to the desired size
478+
tensor = tensor.reshape(size)
479+
480+
# Apply the requested layout
481+
tensor = self.__apply_layout(tensor, layout)
482+
483+
# Set requires_grad if specified
484+
if requires_grad:
485+
tensor.requires_grad_()
486+
487+
return tensor
427488

428489
def uniform(self, size, low=0.0, high=1.0, dtype=torch.float):
429490
self.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}")

0 commit comments

Comments
 (0)