forked from Yancey0623/gotorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtensor_ctor.go
More file actions
51 lines (46 loc) · 1.38 KB
/
tensor_ctor.go
File metadata and controls
51 lines (46 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
package gotorch
// #cgo CFLAGS: -I ${SRCDIR}/cgotorch
// #cgo LDFLAGS: -L ${SRCDIR}/cgotorch -Wl,-rpath ${SRCDIR}/cgotorch -lcgotorch
// #cgo LDFLAGS: -L ${SRCDIR}/cgotorch/libtorch/lib -Wl,-rpath ${SRCDIR}/cgotorch/libtorch/lib -lc10 -ltorch -ltorch_cpu
// #include "cgotorch.h"
import "C"
import (
"unsafe"
)
// RandN returns a tensor of the given shape filled with samples from the
// standard normal distribution, mirroring torch.randn.
//
// shape lists the size of each dimension. An empty (or nil) shape is passed to
// the C layer as a zero-length dimension list (a 0-dim tensor request) rather
// than panicking on shape[0]. requiresGrad is forwarded to the C call as 1/0.
//
// A non-nil error from the C call is handed to MustNil (defined elsewhere in
// this package), which presumably panics on failure.
func RandN(shape []int64, requiresGrad bool) Tensor {
	var rg C.int64_t
	if requiresGrad {
		rg = 1
	}
	// &shape[0] panics on an empty slice, so send a nil pointer with length 0
	// instead. NOTE(review): assumes the C side accepts NULL when the length
	// is 0 (c10::IntArrayRef permits this) — confirm against cgotorch.
	var dims *C.int64_t
	if len(shape) > 0 {
		dims = (*C.int64_t)(unsafe.Pointer(&shape[0]))
	}
	// The C call writes the new tensor handle into t.
	var t C.Tensor
	MustNil(unsafe.Pointer(C.RandN(dims, C.int64_t(len(shape)), rg, &t)))
	// Presumably releases the C tensor once the Go handle is collected.
	SetTensorFinalizer((*unsafe.Pointer)(&t))
	return Tensor{(*unsafe.Pointer)(&t)}
}
// Rand returns a tensor of the given shape filled with random numbers,
// mirroring torch.rand (uniform samples on the interval [0, 1)).
//
// shape lists the size of each dimension. An empty (or nil) shape is passed to
// the C layer as a zero-length dimension list (a 0-dim tensor request) rather
// than panicking on shape[0]. requiresGrad is forwarded to the C call as 1/0.
//
// A non-nil error from the C call is handed to MustNil (defined elsewhere in
// this package), which presumably panics on failure.
func Rand(shape []int64, requiresGrad bool) Tensor {
	var rg C.int64_t
	if requiresGrad {
		rg = 1
	}
	// &shape[0] panics on an empty slice, so send a nil pointer with length 0
	// instead. NOTE(review): assumes the C side accepts NULL when the length
	// is 0 (c10::IntArrayRef permits this) — confirm against cgotorch.
	var dims *C.int64_t
	if len(shape) > 0 {
		dims = (*C.int64_t)(unsafe.Pointer(&shape[0]))
	}
	// The C call writes the new tensor handle into t.
	var t C.Tensor
	MustNil(unsafe.Pointer(C.Rand(dims, C.int64_t(len(shape)), rg, &t)))
	// Presumably releases the C tensor once the Go handle is collected.
	SetTensorFinalizer((*unsafe.Pointer)(&t))
	return Tensor{(*unsafe.Pointer)(&t)}
}
// Empty returns a tensor of the given shape whose contents are uninitialized,
// mirroring torch.empty. The values are whatever happened to be in the
// allocated memory; they are not random samples and must be written before
// they are read.
//
// shape lists the size of each dimension. An empty (or nil) shape is passed to
// the C layer as a zero-length dimension list (a 0-dim tensor request) rather
// than panicking on shape[0]. requiresGrad is forwarded to the C call as 1/0.
//
// A non-nil error from the C call is handed to MustNil (defined elsewhere in
// this package), which presumably panics on failure.
func Empty(shape []int64, requiresGrad bool) Tensor {
	var rg C.int64_t
	if requiresGrad {
		rg = 1
	}
	// &shape[0] panics on an empty slice, so send a nil pointer with length 0
	// instead. NOTE(review): assumes the C side accepts NULL when the length
	// is 0 (c10::IntArrayRef permits this) — confirm against cgotorch.
	var dims *C.int64_t
	if len(shape) > 0 {
		dims = (*C.int64_t)(unsafe.Pointer(&shape[0]))
	}
	// The C call writes the new tensor handle into t.
	var t C.Tensor
	MustNil(unsafe.Pointer(C.Empty(dims, C.int64_t(len(shape)), rg, &t)))
	// Presumably releases the C tensor once the Go handle is collected.
	SetTensorFinalizer((*unsafe.Pointer)(&t))
	return Tensor{(*unsafe.Pointer)(&t)}
}